Do not use auto optimizer.
This commit is contained in:
parent 990e27ba15
commit f411b1cc5e
@@ -172,6 +172,7 @@ class LightModule(pl.LightningModule):
         mconf = conf.model_config
         use_tril_attention_mask = conf.use_tril_attention_mask
         super().__init__()
+        self.automatic_optimization = False
         self.save_hyperparameters()
         if mconf == None:
             mconf = ModelConfig()
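
Reviewer note, not part of the diff: a minimal sketch of the manual-optimization pattern this flag switches on, assuming pytorch-lightning 2.x imported as pl (as in this file); the toy model, data shapes, and learning rate are made up for illustration. With automatic_optimization set to False, the Trainer stops calling backward(), zero_grad() and step() itself, so training_step has to do it, exactly as the next hunk does.

import torch
import pytorch_lightning as pl

class ManualOptSketch(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # Trainer will no longer run backward()/zero_grad()/step() for us
        self.automatic_optimization = False
        self.layer = torch.nn.Linear(8, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.mse_loss(self.layer(x), y)
        opt = self.optimizers()        # single configured optimizer -> single wrapper
        opt.zero_grad()
        self.manual_backward(loss)     # use instead of loss.backward()
        opt.step()
        return loss

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=1e-3)
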
@@ -205,6 +206,19 @@ class LightModule(pl.LightningModule):
         if self.use_tril_attention_mask:
             batch["attention_mask"] = self.get_batch_tril_matrix(block_size, batch_size=batch_size).to(self.device)
         outputs, loss = self.llm(**batch)
+
+        # optimizer = self.optimizers()
+        # for o in optimizer:
+        #     o.zero_grad()
+        # self.manual_backward(loss)
+        # for o in optimizer:
+        #     o.step()
+
+        optimizer = self.optimizers()
+        optimizer.zero_grad()
+        self.manual_backward(loss)
+        optimizer.step()
+
         self.log("train_loss", loss, rank_zero_only=True)
         return loss
 
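
Reviewer note, not part of the diff: the commented-out loop above is the shape training_step takes when configure_optimizers() returns more than one optimizer, because self.optimizers() then returns a list and each entry has to be zeroed and stepped. A small self-contained sketch of that case, assuming pytorch-lightning 2.x; the two-part toy model and the learning rates are made up for illustration.

import torch
import pytorch_lightning as pl

class TwoOptimizerSketch(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False
        self.encoder = torch.nn.Linear(8, 8)
        self.head = torch.nn.Linear(8, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.mse_loss(self.head(self.encoder(x)), y)
        opts = self.optimizers()       # a list, because two optimizers are configured
        for opt in opts:
            opt.zero_grad()
        self.manual_backward(loss)     # one backward pass fills gradients for both
        for opt in opts:
            opt.step()
        return loss

    def configure_optimizers(self):
        # separate learning rates per part, in the spirit of the commented-out code
        return [
            torch.optim.AdamW(self.encoder.parameters(), lr=1e-3),
            torch.optim.AdamW(self.head.parameters(), lr=3e-3),
        ]
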
@@ -232,6 +246,24 @@ class LightModule(pl.LightningModule):
         optimizer = torch.optim.AdamW(self.trainer.model.parameters(), lr=self.learning_rate)
         return optimizer
 
+        # p = self.trainer.model.named_parameters()
+        # p0 = []
+        # p1 = []
+        # p2 = []
+        # for name, param in p:
+        #     if "h.0" in name or ".wte" in name:
+        #         p0.append(param)
+        #         continue
+        #     if "h.1" in name:
+        #         p1.append(param)
+        #         continue
+        #     p2.append(param)
+
+        # optimizer0 = torch.optim.AdamW(p0, lr=self.learning_rate / 1.5)
+        # optimizer1 = torch.optim.AdamW(p1, lr=self.learning_rate)
+        # optimizer2 = torch.optim.AdamW(p2, lr=self.learning_rate * 1.5)
+        # return [optimizer0, optimizer1, optimizer2]
+
     def configure_callbacks(self):
         checkpoint_callback = pl.callbacks.ModelCheckpoint(
             monitor="accuracy",
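
Reviewer note, not part of the diff: the commented-out block keeps three separate optimizers only to get three learning rates. If that experiment is revived, the same effect is available from a single AdamW with parameter groups, which also leaves the single-optimizer training_step above unchanged. A sketch under that assumption: build_grouped_adamw is a hypothetical helper, not something in this repository, and the name filters ("h.0", "h.1", ".wte") are copied from the comments above.

import torch

def build_grouped_adamw(model: torch.nn.Module, learning_rate: float) -> torch.optim.AdamW:
    # Split parameters the same way the commented-out code does
    p0, p1, p2 = [], [], []
    for name, param in model.named_parameters():
        if "h.0" in name or ".wte" in name:
            p0.append(param)           # embeddings + first block: lower LR
        elif "h.1" in name:
            p1.append(param)           # second block: base LR
        else:
            p2.append(param)           # everything else: higher LR
    # One optimizer, three parameter groups, each with its own learning rate
    return torch.optim.AdamW(
        [
            {"params": p0, "lr": learning_rate / 1.5},
            {"params": p1, "lr": learning_rate},
            {"params": p2, "lr": learning_rate * 1.5},
        ],
        lr=learning_rate,
    )

Inside this module that would amount to returning build_grouped_adamw(self.trainer.model, self.learning_rate) from configure_optimizers().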