From 3f92bbbaa2109e6e2a1ebeaac99ca0bcc061c31d Mon Sep 17 00:00:00 2001
From: Yiqing-Zhou
Date: Tue, 9 May 2023 00:02:29 +0800
Subject: [PATCH] [feature] new args learning_rate max_epochs

---
 lit_module.py | 12 ++++++++++--
 lit_train.py  | 17 ++++++++++++++++-
 2 files changed, 26 insertions(+), 3 deletions(-)

diff --git a/lit_module.py b/lit_module.py
index 789c58e..f2390dd 100644
--- a/lit_module.py
+++ b/lit_module.py
@@ -9,10 +9,16 @@ from utils import init_model
 
 
 class LitModule(pl.LightningModule):
-    def __init__(self, model_name: str, use_tril_attention_mask: str = False):
+    def __init__(
+        self,
+        model_name: str,
+        learning_rate: float = 0.0001,
+        use_tril_attention_mask: str = False,
+    ):
         super().__init__()
         self.save_hyperparameters()
         self.llm = self.register_core_module(init_model(model_name))
+        self.learning_rate = learning_rate
         self.use_tril_attention_mask = use_tril_attention_mask
         self.metric_loss = torchmetrics.MeanMetric()
         self.metric_accuracy = torchmetrics.Accuracy(
@@ -62,7 +68,9 @@ class LitModule(pl.LightningModule):
         self.log('accuracy', self.metric_accuracy, rank_zero_only=True)
 
     def configure_optimizers(self):
-        optimizer = torch.optim.AdamW(self.trainer.model.parameters(), lr=0.0001)
+        optimizer = torch.optim.AdamW(
+            self.trainer.model.parameters(), lr=self.learning_rate
+        )
         return optimizer
 
     def configure_callbacks(self):
diff --git a/lit_train.py b/lit_train.py
index 3e16cf0..cfead59 100644
--- a/lit_train.py
+++ b/lit_train.py
@@ -80,6 +80,12 @@ def parse_args():
         help="Name of or path to model",
         default='gpt2',
     )
+    parser.add_argument(
+        "--learning_rate",
+        type=float,
+        help="Learning rate",
+        default=0.0001,
+    )
     parser.add_argument(
         "--use_tril_attention_mask",
         help="Use tril attention mask during training",
@@ -124,6 +130,12 @@ def parse_args():
         help="Number of data processes",
         default=16,
     )
+    parser.add_argument(
+        "--max_epochs",
+        type=int,
+        help="Max epochs",
+        default=None,
+    )
     parser.add_argument(
         "--strategy",
         type=str,
@@ -155,7 +167,9 @@ if __name__ == '__main__':
     set_seed(args.seed)
 
     # lightning module
-    lit_module = LitModule(args.model_name, args.use_tril_attention_mask)
+    lit_module = LitModule(
+        args.model_name, args.learning_rate, args.use_tril_attention_mask
+    )
 
     # datasets
     tokenizer = load_tokenizer(args.tokenizer_name_or_path)
@@ -204,6 +218,7 @@ if __name__ == '__main__':
         log_every_n_steps=5,
         accumulate_grad_batches=args.accumulate_grad_batches,
         strategy=args.strategy,
+        max_epochs=args.max_epochs,
     )
     lit_trainer.fit(
         lit_module,
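
Editor's note: below is a minimal self-contained sketch (not part of the patch) of how the two new flags flow through. It assumes `import pytorch_lightning as pl`, matching the `pl.` prefix used in both files; `TinyModule` and the literal values passed at the bottom are hypothetical stand-ins for the real `LitModule` and the parsed `args`.

import torch
import pytorch_lightning as pl


class TinyModule(pl.LightningModule):
    def __init__(self, learning_rate: float = 0.0001):
        super().__init__()
        # save_hyperparameters() records learning_rate in checkpoints/logs,
        # mirroring what LitModule does in lit_module.py.
        self.save_hyperparameters()
        self.learning_rate = learning_rate
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        # The CLI value lands here instead of the previously hard-coded lr=0.0001.
        return torch.optim.AdamW(self.parameters(), lr=self.learning_rate)


module = TinyModule(learning_rate=3e-4)  # stands in for args.learning_rate
trainer = pl.Trainer(max_epochs=3)       # stands in for args.max_epochs (patch default: None)

With the patch applied, the equivalent invocation would be something like: python lit_train.py --learning_rate 3e-4 --max_epochs 3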