2024-02-25 20:20:32 +08:00
|
|
|
from functools import cache
|
|
|
|
from typing import Dict, Optional
|
|
|
|
|
|
|
|
import pytorch_lightning as pl
|
|
|
|
import torch
|
|
|
|
import torchmetrics
|
|
|
|
|
|
|
|
from modeling_wit import QWenLMHeadModel
|
2024-03-14 11:40:26 +08:00
|
|
|
from wit.configuration import ModelConfig
|
2024-02-25 20:20:32 +08:00
|
|
|
|
|
|
|
from transformers import AutoConfig
|
|
|
|
|
|
|
|
|
|
|
|
class LitModule(pl.LightningModule):
    """Lightning module that trains and validates a ``QWenLMHeadModel``.

    The wrapped model is stored as ``self.llm``. During validation the loss is
    accumulated in a ``torchmetrics.MeanMetric`` and next-token accuracy in a
    multiclass ``torchmetrics.Accuracy``; both are logged at the end of every
    validation epoch under the keys ``val_loss`` and ``accuracy``.
    """

    def __init__(
        self,
        pretrained_model_dir: Optional[str] = None,
        learning_rate: float = 0.0001,
        config: Optional[ModelConfig] = None,
        use_tril_attention_mask: bool = False,
    ):
        """Build the wrapped model and the validation metrics.

        Args:
            pretrained_model_dir: When not ``None``, passed to
                ``modelscope.snapshot_download`` and the resulting path is
                handed to the model's ``from_pretrained``.
            learning_rate: Learning rate given to AdamW in
                ``configure_optimizers``.
            config: Model configuration; a default ``ModelConfig()`` is
                created when omitted.
            use_tril_attention_mask: When truthy, ``training_step`` replaces
                ``batch["attention_mask"]`` with a lower-triangular mask.
        """
        super().__init__()
        # Must run before any argument is rebound so the recorded
        # hyperparameters are exactly what the caller passed in.
        self.save_hyperparameters()

        if config is None:
            config = ModelConfig()
        model = QWenLMHeadModel(config)
        if pretrained_model_dir is not None:
            # Imported lazily so modelscope is only required when a
            # pretrained checkpoint is actually requested.
            from modelscope import snapshot_download

            # NOTE(review): from_pretrained is invoked on the instance on
            # purpose; QWenLMHeadModel is project-local and may not expose it
            # as a classmethod -- confirm before "simplifying" this call.
            model = model.from_pretrained(snapshot_download(pretrained_model_dir))

        self.llm = self.register_core_module(model)
        self.learning_rate = learning_rate
        self.use_tril_attention_mask = use_tril_attention_mask
        self.metric_loss = torchmetrics.MeanMetric()
        self.vocab_size = self.llm.config.vocab_size
        self.metric_accuracy = torchmetrics.Accuracy(
            task="multiclass",
            num_classes=self.vocab_size,
        )

    def get_batch_tril_matrix(self, block_size: int, batch_size: Optional[int] = None) -> torch.Tensor:
        """Return a lower-triangular matrix of ones, memoized per instance.

        Args:
            block_size: Side length of the square matrix.
            batch_size: When given, the matrix is repeated along a new leading
                dimension of this size.

        Returns:
            A tensor of shape ``(block_size, block_size)``, or
            ``(batch_size, block_size, block_size)`` when ``batch_size`` is
            given. Repeated calls with the same arguments return the same
            tensor object, so callers must not modify it in place.
        """
        # The memo lives on the instance rather than in a process-wide
        # functools.cache: a cache on a method keys on ``self`` and would keep
        # every LitModule (and its tensors) alive for the life of the process.
        memo = self.__dict__.setdefault("_tril_matrix_cache", {})
        key = (block_size, batch_size)
        matrix = memo.get(key)
        if matrix is None:
            matrix = torch.ones(block_size, block_size).tril()
            if batch_size is not None:
                matrix = matrix.repeat(batch_size, 1, 1)
            memo[key] = matrix
        return matrix

    def register_core_module(self, module: torch.nn.Module) -> torch.nn.Module:
        """Record ``module`` under the ``__core_module__`` attribute and return it.

        ``object.__setattr__`` is used, so the assignment bypasses this class's
        own ``__setattr__`` handling.
        """
        object.__setattr__(self, "__core_module__", module)
        return module

    def training_step(self, batch: Dict[str, torch.Tensor], batch_idx):
        """Run one forward pass on ``batch``, log ``train_loss`` and return the loss.

        When ``use_tril_attention_mask`` is set, ``batch["attention_mask"]`` is
        overwritten in place with a lower-triangular mask sized from
        ``batch["input_ids"]``.
        """
        batch_size, block_size = batch["input_ids"].shape
        if self.use_tril_attention_mask:
            batch["attention_mask"] = self.get_batch_tril_matrix(block_size, batch_size=batch_size).to(self.device)
        # The wrapped model returns a pair that is unpacked as (outputs, loss).
        outputs, loss = self.llm(**batch)
        # NOTE(review): rank_zero_only=True is kept from the original; see the
        # Lightning ``log`` documentation for its restrictions under multi-device runs.
        self.log("train_loss", loss, rank_zero_only=True)
        return loss

    def validation_step(self, batch: Dict[str, torch.Tensor], batch_idx):
        """Update the validation loss and accuracy metrics from one batch.

        NOTE(review): unlike ``training_step`` this never applies the
        triangular attention mask -- confirm that is intended.
        """
        outputs, loss = self.llm(**batch, return_dict=True)

        # Align predictions with targets: drop the last step of the outputs and
        # the first label, then flatten both to one row per position.
        logits = outputs[..., :-1, :]
        logits = logits.contiguous().view(-1, logits.size(-1))
        labels = batch["labels"][..., 1:]
        labels = labels.contiguous().view(-1)

        # Keep only positions whose label is a valid class id. The upper bound
        # is the original filter; the lower bound additionally drops negative
        # ids (presumably ignore/padding markers such as -100 -- verify against
        # the data pipeline), which are outside the metric's class range.
        label_mask = (labels >= 0) & (labels < self.vocab_size)
        logits_m = logits[label_mask]
        labels_m = labels[label_mask]

        self.metric_accuracy.update(logits_m, labels_m)
        self.metric_loss.update(loss)

    def on_validation_epoch_end(self) -> None:
        """Log the accumulated validation metrics as ``val_loss`` and ``accuracy``."""
        # NOTE(review): the Lightning docs state that values logged with
        # rank_zero_only=True cannot be used as a callback monitor; revisit
        # this flag before re-enabling the checkpoint / early-stopping
        # callbacks that monitor "accuracy".
        self.log("val_loss", self.metric_loss, rank_zero_only=True)
        self.log("accuracy", self.metric_accuracy, rank_zero_only=True)

    def configure_optimizers(self):
        """Return an AdamW optimizer over ``self.trainer.model.parameters()``."""
        optimizer = torch.optim.AdamW(self.trainer.model.parameters(), lr=self.learning_rate)
        return optimizer

    def configure_callbacks(self):
        """Return the callbacks attached by this module.

        Only the learning-rate monitor is returned. The checkpoint and
        early-stopping callbacks are still constructed, exactly as before, so
        their settings stay in one place, but they are not activated.
        """
        checkpoint_callback = pl.callbacks.ModelCheckpoint(
            monitor="accuracy",
            mode="max",
            filename="{epoch:02d}-{accuracy:.4f}",
        )
        early_stop_callback = pl.callbacks.EarlyStopping(
            monitor="accuracy",
            min_delta=0.001,
            patience=3,
            mode="max",
            stopping_threshold=1,
        )
        lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval="step")
        # To activate the others, add checkpoint_callback and/or
        # early_stop_callback to the returned list.
        return [lr_monitor]