gpt-pretrain/lit_module.py

106 lines
3.8 KiB
Python
Raw Normal View History

2023-05-07 13:01:02 +08:00
from functools import cache
from typing import Dict, Optional
import pytorch_lightning as pl
import torch
import torchmetrics
from utils import init_model
2024-02-24 13:40:39 +08:00
from custom_models.gpt2.modeling_gpt2 import GPT2LMHeadModel
2023-05-07 13:01:02 +08:00
2024-02-24 13:44:22 +08:00
from transformers import AutoConfig
2023-05-07 13:01:02 +08:00
class LitModule(pl.LightningModule):
    """Lightning wrapper that trains and validates a causal language model.

    The wrapped model is stored as ``self.llm``. When ``path`` is non-empty it
    is loaded with ``GPT2LMHeadModel.from_pretrained(path)``. Otherwise it is
    built by ``init_model(model_name)``, a project helper from ``utils``. That
    helper is assumed to return a freshly initialised Hugging Face model
    exposing ``.config.vocab_size`` (TODO confirm).

    Args:
        model_name: Passed to ``init_model`` when ``path`` is empty.
        path: Location passed to ``from_pretrained``. An empty string (or any
            falsy value) means "build the model from ``model_name``".
        learning_rate: Learning rate of the optimizer created in
            ``configure_optimizers``.
        use_tril_attention_mask: If true, the batch's ``attention_mask`` is
            replaced by a lower-triangular matrix of ones of shape
            ``(batch_size, block_size, block_size)`` before every forward pass
            (training and validation).
    """

    def __init__(
        self,
        model_name: str,
        path: str = "",
        learning_rate: float = 0.0001,
        use_tril_attention_mask: bool = False,
    ):
        super().__init__()
        # Stores the __init__ arguments in self.hparams.
        self.save_hyperparameters()
        if path:
            # ``from_pretrained`` is a classmethod that reads its config from
            # ``path``. There is no need to first build (and then discard) a
            # randomly initialised model from a separately created config.
            model = GPT2LMHeadModel.from_pretrained(path)
        else:
            model = init_model(model_name)
        self.llm = self.register_core_module(model)
        self.learning_rate = learning_rate
        self.use_tril_attention_mask = use_tril_attention_mask
        # Mean of the per-batch validation losses.
        self.metric_loss = torchmetrics.MeanMetric()
        # Next-token accuracy over the full vocabulary.
        self.metric_accuracy = torchmetrics.Accuracy(
            task="multiclass",
            num_classes=self.llm.config.vocab_size,
        )

    @staticmethod
    @cache
    def _tril_matrix(block_size: int, batch_size: Optional[int]) -> torch.Tensor:
        """Build, once per argument pair, a lower-triangular matrix of ones.

        The cache lives on a staticmethod rather than on the bound method, so it
        does not hold a reference to ``self``. ``functools.cache`` on a method
        would keep every instance alive for the lifetime of the cache.
        """
        matrix = torch.ones(block_size, block_size).tril()
        if batch_size is not None:
            matrix = matrix.repeat(batch_size, 1, 1)
        return matrix

    def get_batch_tril_matrix(self, block_size: int, batch_size: Optional[int] = None) -> torch.Tensor:
        """Return a cached lower-triangular matrix of ones.

        The shape is ``(block_size, block_size)``, or ``(batch_size, block_size,
        block_size)`` when ``batch_size`` is given. The tensor is created on
        torch's default device and shared between calls with the same
        arguments, so callers must not modify it in place.
        """
        # Always call positionally so equal arguments map to one cache entry.
        return self._tril_matrix(block_size, batch_size)

    def register_core_module(self, module: torch.nn.Module) -> torch.nn.Module:
        """Attach ``module`` as the plain attribute ``__core_module__`` and return it.

        ``object.__setattr__`` bypasses ``nn.Module.__setattr__``, so this
        attribute is not registered as a submodule. The caller still registers
        the same module as ``self.llm``. NOTE(review): presumably lets external
        code locate the wrapped model without a duplicate submodule entry.
        Confirm against its consumers.
        """
        object.__setattr__(self, "__core_module__", module)
        return module

    def _model_inputs(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Return the keyword arguments to feed ``self.llm`` for ``batch``.

        With ``use_tril_attention_mask`` set, ``attention_mask`` is replaced by
        the tril matrix sized from ``batch["input_ids"].shape``. That shape must
        be 2-D, assumed ``(batch_size, block_size)``. A new dict is returned, so
        the caller's batch is left unchanged.
        """
        if not self.use_tril_attention_mask:
            return batch
        batch_size, block_size = batch["input_ids"].shape
        mask = self.get_batch_tril_matrix(block_size, batch_size=batch_size).to(self.device)
        return {**batch, "attention_mask": mask}

    def training_step(self, batch: Dict[str, torch.Tensor], batch_idx):
        """Run one forward pass and return the model-computed loss.

        The loss is computed by ``self.llm`` from the batch, presumably from its
        ``labels`` entry.
        """
        outputs = self.llm(**self._model_inputs(batch), return_dict=True)
        loss = outputs.loss
        self.log("train_loss", loss, rank_zero_only=True)
        return loss

    def validation_step(self, batch: Dict[str, torch.Tensor], batch_idx):
        """Accumulate the validation loss and next-token accuracy for one batch."""
        # Same input preparation as training, so the model is evaluated under
        # the attention mask it was trained with.
        outputs = self.llm(**self._model_inputs(batch), return_dict=True)
        self.metric_loss.update(outputs.loss)
        # Logits at position t are scored against the label at t + 1: drop the
        # last logit and the first label so the two sequences line up.
        logits = outputs.logits[..., :-1, :]
        labels = batch["labels"][..., 1:]
        # Positions labelled -100 are excluded from accuracy.
        label_mask = labels != -100
        self.metric_accuracy.update(logits[label_mask], labels[label_mask])

    def on_validation_epoch_end(self) -> None:
        """Log the epoch-level validation loss and accuracy.

        These are logged from every rank, without ``rank_zero_only``. The
        checkpoint and early-stopping callbacks monitor ``accuracy``, and
        Lightning documents that a metric logged with ``rank_zero_only=True``
        cannot be used as a callback monitor.
        """
        self.log("val_loss", self.metric_loss)
        self.log("accuracy", self.metric_accuracy)

    def configure_optimizers(self):
        """Create the optimizer over ``self.trainer.model``'s parameters.

        Under a DeepSpeed strategy whose ``zero_optimization`` config contains
        ``offload_optimizer``, DeepSpeed's ``DeepSpeedCPUAdam`` is used.
        Otherwise ``torch.optim.AdamW`` is used. Both use ``self.learning_rate``.

        Raises:
            ValueError: If the DeepSpeed config already has an ``optimizer``
                entry. This method is meant to be the only source of the
                optimizer.
        """
        strategy = self.trainer.strategy
        if isinstance(strategy, pl.strategies.DeepSpeedStrategy):
            # Explicit check instead of ``assert`` so it also runs under ``python -O``.
            if "optimizer" in strategy.config:
                raise ValueError(
                    "The DeepSpeed config must not define 'optimizer'; "
                    "it is created in LitModule.configure_optimizers."
                )
            zero_config = strategy.config.get("zero_optimization")
            if zero_config is not None and "offload_optimizer" in zero_config:
                # Imported lazily so deepspeed is only required when used.
                from deepspeed.ops.adam import DeepSpeedCPUAdam

                return DeepSpeedCPUAdam(self.trainer.model.parameters(), lr=self.learning_rate)
        return torch.optim.AdamW(self.trainer.model.parameters(), lr=self.learning_rate)

    def configure_callbacks(self):
        """Checkpoint on the best validation ``accuracy``, and stop early when it plateaus."""
        checkpoint_callback = pl.callbacks.ModelCheckpoint(
            monitor="accuracy",
            mode="max",
            filename="{epoch:02d}-{accuracy:.4f}",
        )
        early_stop_callback = pl.callbacks.EarlyStopping(
            monitor="accuracy",
            # Gains smaller than 0.001 count as "no improvement"; stop after
            # 3 such validation checks in a row.
            min_delta=0.001,
            patience=3,
            mode="max",
            # Accuracy cannot exceed 1, so stop as soon as it gets there.
            stopping_threshold=1,
        )
        return [checkpoint_callback, early_stop_callback]