Do not use auto optimizer.

Colin 2025-03-13 14:28:53 +08:00
parent 990e27ba15
commit f411b1cc5e
1 changed file with 32 additions and 0 deletions


@@ -172,6 +172,7 @@ class LightModule(pl.LightningModule):
        mconf = conf.model_config
        use_tril_attention_mask = conf.use_tril_attention_mask
        super().__init__()
        self.automatic_optimization = False
        self.save_hyperparameters()
        if mconf == None:
            mconf = ModelConfig()
@@ -205,6 +206,19 @@ class LightModule(pl.LightningModule):
        if self.use_tril_attention_mask:
            batch["attention_mask"] = self.get_batch_tril_matrix(block_size, batch_size=batch_size).to(self.device)
        outputs, loss = self.llm(**batch)
        # optimizer = self.optimizers()
        # for o in optimizer:
        #     o.zero_grad()
        # self.manual_backward(loss)
        # for o in optimizer:
        #     o.step()
        optimizer = self.optimizers()
        optimizer.zero_grad()
        self.manual_backward(loss)
        optimizer.step()
        self.log("train_loss", loss, rank_zero_only=True)
        return loss
@@ -232,6 +246,24 @@ class LightModule(pl.LightningModule):
        optimizer = torch.optim.AdamW(self.trainer.model.parameters(), lr=self.learning_rate)
        return optimizer
        # p = self.trainer.model.named_parameters()
        # p0 = []
        # p1 = []
        # p2 = []
        # for name, param in p:
        #     if "h.0" in name or ".wte" in name:
        #         p0.append(param)
        #         continue
        #     if "h.1" in name:
        #         p1.append(param)
        #         continue
        #     p2.append(param)
        # optimizer0 = torch.optim.AdamW(p0, lr=self.learning_rate / 1.5)
        # optimizer1 = torch.optim.AdamW(p1, lr=self.learning_rate)
        # optimizer2 = torch.optim.AdamW(p2, lr=self.learning_rate * 1.5)
        # return [optimizer0, optimizer1, optimizer2]

    def configure_callbacks(self):
        checkpoint_callback = pl.callbacks.ModelCheckpoint(
            monitor="accuracy",