[feature] persistent_workers; new args: accumulate_grad_batches, strategy

Yiqing-Zhou 2023-05-06 19:39:24 +08:00
parent 45fa065530
commit 9b7f9b9d60
1 changed file with 16 additions and 3 deletions

@@ -8,7 +8,6 @@ import datasets
 import pytorch_lightning as pl
 import torch
 import torchmetrics
-from pytorch_lightning import strategies
 from torch.utils.data import ConcatDataset, DataLoader
 from transformers import (
     AutoConfig,
@@ -131,12 +130,24 @@ def parse_args():
         help="Batch size of validating",
         default=16,
     )
+    parser.add_argument(
+        "--accumulate_grad_batches",
+        type=int,
+        help="Accumulate grad batches",
+        default=32,
+    )
     parser.add_argument(
         "--num_proc",
         type=str,
         help="Number of data processes",
         default=16,
     )
+    parser.add_argument(
+        "--strategy",
+        type=str,
+        help="Name of pytorch lightning distribution strategy",
+        default='fsdp',
+    )
     parser.add_argument(
         "--resume_from_ckpt_path",
         type=str,
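The two new flags replace values that were previously hard-coded in the Trainer construction. A minimal sketch of how they flow through, assuming a stand-in script name (train.py is not named in the diff) and keeping only the parts this commit touches:

# Hypothetical invocation (script name assumed):
#   python train.py --accumulate_grad_batches 8 --strategy deepspeed_stage_3
import argparse

import pytorch_lightning as pl

parser = argparse.ArgumentParser()
parser.add_argument("--accumulate_grad_batches", type=int, default=32)
parser.add_argument("--strategy", type=str, default="fsdp")
args = parser.parse_args()

trainer = pl.Trainer(
    accelerator="gpu",
    # The optimizer steps once every N batches, simulating a larger batch.
    accumulate_grad_batches=args.accumulate_grad_batches,
    # A registry string such as "fsdp", "ddp", or "deepspeed_stage_3".
    strategy=args.strategy,
)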
@@ -256,6 +267,7 @@ if __name__ == '__main__':
         batch_size=args.train_batch_size,
         num_workers=args.num_proc,
         collate_fn=DefaultDataCollator(),
+        persistent_workers=True,
         shuffle=True,
     )
     val_dataloader = DataLoader(
@@ -263,6 +275,7 @@ if __name__ == '__main__':
         batch_size=args.val_batch_size,
         num_workers=args.num_proc,
         collate_fn=DefaultDataCollator(),
+        persistent_workers=True,
     )
 
     # trainer
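persistent_workers=True keeps the DataLoader's worker processes alive across epochs instead of shutting them down and re-forking them each time the iterator is exhausted, saving the per-epoch startup cost (re-importing modules, re-initializing the dataset). Note that PyTorch raises a ValueError when persistent_workers=True is combined with num_workers=0, so num_proc must stay positive here. A minimal self-contained sketch (the toy dataset is illustrative, not from the commit):

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(1024, dtype=torch.float32))

loader = DataLoader(
    dataset,
    batch_size=16,
    num_workers=4,            # must be > 0 when persistent_workers=True
    persistent_workers=True,  # workers survive between epochs
)

if __name__ == "__main__":
    for epoch in range(3):
        for (batch,) in loader:  # the same worker processes serve every epoch
            pass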
@@ -277,8 +290,8 @@ if __name__ == '__main__':
         accelerator='gpu',
         precision=precision,
         log_every_n_steps=5,
-        accumulate_grad_batches=32,
-        strategy=strategies.DeepSpeedStrategy(stage=3),
+        accumulate_grad_batches=args.accumulate_grad_batches,
+        strategy=args.strategy,
     )
     lit_trainer.fit(
         lit_module,
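Passing a plain string to strategy= resolves it through Lightning's strategy registry, which is what makes the `from pytorch_lightning import strategies` import removed at the top of the file unnecessary. The previously hard-coded behavior also stays reachable: "deepspeed_stage_3" is Lightning's registry alias for DeepSpeedStrategy(stage=3), so (assuming that is the intent) the old configuration corresponds to running with --strategy deepspeed_stage_3. A sketch of the equivalence:

import pytorch_lightning as pl
from pytorch_lightning import strategies

# Before this commit: the strategy object constructed by hand
# (both forms require the deepspeed package to be installed).
trainer_old = pl.Trainer(strategy=strategies.DeepSpeedStrategy(stage=3))

# After this commit: the same strategy selected via its registry alias,
# now controllable from the command line.
trainer_new = pl.Trainer(strategy="deepspeed_stage_3")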