[feature] add new args: learning_rate, max_epochs
parent dc2941f790
commit 3f92bbbaa2
@@ -9,10 +9,16 @@ from utils import init_model
 
 
 class LitModule(pl.LightningModule):
-    def __init__(self, model_name: str, use_tril_attention_mask: str = False):
+    def __init__(
+        self,
+        model_name: str,
+        learning_rate: float = 0.0001,
+        use_tril_attention_mask: bool = False,
+    ):
         super().__init__()
         self.save_hyperparameters()
         self.llm = self.register_core_module(init_model(model_name))
+        self.learning_rate = learning_rate
         self.use_tril_attention_mask = use_tril_attention_mask
         self.metric_loss = torchmetrics.MeanMetric()
         self.metric_accuracy = torchmetrics.Accuracy(
@@ -62,7 +68,9 @@ class LitModule(pl.LightningModule):
         self.log('accuracy', self.metric_accuracy, rank_zero_only=True)
 
     def configure_optimizers(self):
-        optimizer = torch.optim.AdamW(self.trainer.model.parameters(), lr=0.0001)
+        optimizer = torch.optim.AdamW(
+            self.trainer.model.parameters(), lr=self.learning_rate
+        )
         return optimizer
 
     def configure_callbacks(self):
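A minimal usage sketch (not part of the commit) of the new constructor argument, assuming LitModule as defined above is importable and init_model("gpt2") works in this environment; the 3e-4 value is purely illustrative:

    # Hypothetical usage, assuming this repo's environment (pytorch_lightning, torchmetrics, utils).
    lit_module = LitModule("gpt2", learning_rate=3e-4)
    # save_hyperparameters() records the new argument, so it survives checkpointing,
    # and configure_optimizers() reads self.learning_rate when building AdamW.
    print(lit_module.hparams.learning_rate)  # 3e-4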
lit_train.py (17 changed lines)
@@ -80,6 +80,12 @@ def parse_args():
         help="Name of or path to model",
         default='gpt2',
     )
+    parser.add_argument(
+        "--learning_rate",
+        type=float,
+        help="Learning rate",
+        default=0.0001,
+    )
     parser.add_argument(
         "--use_tril_attention_mask",
         help="Use tril attention mask during training",
@@ -124,6 +130,12 @@ def parse_args():
         help="Number of data processes",
         default=16,
     )
+    parser.add_argument(
+        "--max_epochs",
+        type=int,
+        help="Max epochs",
+        default=None,
+    )
     parser.add_argument(
         "--strategy",
         type=str,
@@ -155,7 +167,9 @@ if __name__ == '__main__':
     set_seed(args.seed)
 
     # lightning module
-    lit_module = LitModule(args.model_name, args.use_tril_attention_mask)
+    lit_module = LitModule(
+        args.model_name, args.learning_rate, args.use_tril_attention_mask
+    )
 
     # datasets
     tokenizer = load_tokenizer(args.tokenizer_name_or_path)
@@ -204,6 +218,7 @@ if __name__ == '__main__':
         log_every_n_steps=5,
         accumulate_grad_batches=args.accumulate_grad_batches,
         strategy=args.strategy,
+        max_epochs=args.max_epochs,
     )
     lit_trainer.fit(
         lit_module,
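Usage note (an illustrative sketch, not part of the commit): both additions are ordinary argparse options, so they are set on the command line and flow to LitModule and pl.Trainer exactly as in the hunks above. The sketch assumes it sits in lit_train.py's __main__ block, where parse_args, LitModule, and pytorch_lightning as pl are already in scope; the flag values are arbitrary examples, and leaving the flags off keeps the defaults (learning_rate=0.0001, max_epochs=None).

    # Example invocation (hypothetical values):
    #   python lit_train.py --model_name gpt2 --learning_rate 3e-4 --max_epochs 3
    args = parse_args()
    lit_module = LitModule(
        args.model_name, args.learning_rate, args.use_tril_attention_mask
    )
    lit_trainer = pl.Trainer(max_epochs=args.max_epochs)  # remaining Trainer kwargs as in the script
    # lit_trainer.fit(lit_module, train_dataloader, val_dataloader)  # as in the script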