[feature] new args learning_rate max_epochs
parent dc2941f790
commit 3f92bbbaa2
@@ -9,10 +9,16 @@ from utils import init_model
 class LitModule(pl.LightningModule):
-    def __init__(self, model_name: str, use_tril_attention_mask: str = False):
+    def __init__(
+        self,
+        model_name: str,
+        learning_rate: float = 0.0001,
+        use_tril_attention_mask: str = False,
+    ):
         super().__init__()
         self.save_hyperparameters()
         self.llm = self.register_core_module(init_model(model_name))
+        self.learning_rate = learning_rate
         self.use_tril_attention_mask = use_tril_attention_mask
         self.metric_loss = torchmetrics.MeanMetric()
         self.metric_accuracy = torchmetrics.Accuracy(
@@ -62,7 +68,9 @@ class LitModule(pl.LightningModule):
         self.log('accuracy', self.metric_accuracy, rank_zero_only=True)

     def configure_optimizers(self):
-        optimizer = torch.optim.AdamW(self.trainer.model.parameters(), lr=0.0001)
+        optimizer = torch.optim.AdamW(
+            self.trainer.model.parameters(), lr=self.learning_rate
+        )
         return optimizer

     def configure_callbacks(self):
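For context, a minimal usage sketch of the changed module API. This is illustrative only and not part of the commit; it assumes the hunks above plus the standard Lightning behaviour of save_hyperparameters(), and the 2e-4 value is made up.

# hypothetical usage, not part of this commit
module = LitModule('gpt2', learning_rate=2e-4)
# save_hyperparameters() records the constructor arguments, so the rate is also
# available as module.hparams.learning_rate and is stored in checkpoints
assert module.learning_rate == module.hparams.learning_rate == 2e-4
# configure_optimizers() now builds AdamW with lr=self.learning_rate instead of
# the previously hard-coded 0.0001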
lit_train.py (17 changed lines)
@@ -80,6 +80,12 @@ def parse_args():
         help="Name of or path to model",
         default='gpt2',
     )
+    parser.add_argument(
+        "--learning_rate",
+        type=float,
+        help="Learning rate",
+        default=0.0001,
+    )
     parser.add_argument(
         "--use_tril_attention_mask",
         help="Use tril attention mask during training",
@@ -124,6 +130,12 @@ def parse_args():
         help="Number of data processes",
         default=16,
     )
+    parser.add_argument(
+        "--max_epochs",
+        type=int,
+        help="Max epochs",
+        default=None,
+    )
     parser.add_argument(
         "--strategy",
         type=str,
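With the two new flags the script can be launched, for example, as

    python lit_train.py --learning_rate 3e-4 --max_epochs 10

(values are illustrative; omitting either flag keeps the defaults of 0.0001 and None shown above, and the remaining flags are unchanged).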
@@ -155,7 +167,9 @@ if __name__ == '__main__':
     set_seed(args.seed)

     # lightning module
-    lit_module = LitModule(args.model_name, args.use_tril_attention_mask)
+    lit_module = LitModule(
+        args.model_name, args.learning_rate, args.use_tril_attention_mask
+    )

     # datasets
     tokenizer = load_tokenizer(args.tokenizer_name_or_path)
@@ -204,6 +218,7 @@ if __name__ == '__main__':
         log_every_n_steps=5,
         accumulate_grad_batches=args.accumulate_grad_batches,
         strategy=args.strategy,
+        max_epochs=args.max_epochs,
     )
     lit_trainer.fit(
         lit_module,
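An end-to-end sketch of how the new values travel from parse_args() into the module and the Trainer, mirroring the hunks above (illustrative; Trainer-side flags other than max_epochs are omitted here). Since pl.Trainer already defaults max_epochs to None, not passing --max_epochs leaves the previous stopping behaviour unchanged.

# illustrative wiring only, following the diff above
args = parse_args()
lit_module = LitModule(
    args.model_name, args.learning_rate, args.use_tril_attention_mask
)  # learning_rate ends up in configure_optimizers()
lit_trainer = pl.Trainer(max_epochs=args.max_epochs)  # other flags omitted
lit_trainer.fit(lit_module)  # dataloaders omitted in this sketch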