from functools import cache, lru_cache
from typing import Dict, Optional

import pytorch_lightning as pl
import torch
import torchmetrics
from modeling_wit import QWenLMHeadModel
from wit.configuration import ModelConfig
from transformers import AutoConfig


class LitModule(pl.LightningModule):
    """Lightning wrapper around ``QWenLMHeadModel``.

    Logs the training loss per step and accumulates a mean validation loss
    and a multiclass accuracy (one class per vocabulary entry) that are
    logged at the end of every validation epoch.
    """

    def __init__(
        self,
        pretrained_model_dir: Optional[str] = None,
        learning_rate: float = 0.0001,
        config: Optional[ModelConfig] = None,
        use_tril_attention_mask: bool = False,
    ):
        """Build the wrapped model and the validation metrics.

        Args:
            pretrained_model_dir: Identifier passed to modelscope's
                ``snapshot_download``; when given, the model is replaced by
                the result of ``from_pretrained`` on the downloaded snapshot.
            learning_rate: Learning rate handed to AdamW.
            config: Model configuration; a default ``ModelConfig()`` is
                created when omitted.
            use_tril_attention_mask: When truthy, ``training_step`` puts a
                lower-triangular matrix into the batch as ``attention_mask``.
        """
        super().__init__()
        self.save_hyperparameters()

        if config is None:
            config = ModelConfig()
        model = QWenLMHeadModel(config)
        if pretrained_model_dir is not None:
            # Imported lazily so modelscope is only needed when a pretrained
            # checkpoint is actually requested.
            from modelscope import snapshot_download

            # NOTE(review): the config-built model above is discarded here;
            # confirm whether `config` is meant to influence the pretrained
            # model as well.
            model = model.from_pretrained(snapshot_download(pretrained_model_dir))

        self.llm = self.register_core_module(model)
        self.learning_rate = learning_rate
        self.use_tril_attention_mask = use_tril_attention_mask

        self.metric_loss = torchmetrics.MeanMetric()
        self.vocab_size = self.llm.config.vocab_size
        self.metric_accuracy = torchmetrics.Accuracy(
            task="multiclass",
            num_classes=self.vocab_size,
        )

    @staticmethod
    @lru_cache(maxsize=8)
    def _tril_base(block_size: int) -> torch.Tensor:
        """Return a cached ``(block_size, block_size)`` lower-triangular matrix of ones.

        The cache lives on a static helper so it is not keyed on ``self``
        (which would keep module instances alive), and it is bounded so that
        varying sequence lengths cannot grow it without limit.
        """
        return torch.ones(block_size, block_size).tril()

    def get_batch_tril_matrix(self, block_size: int, batch_size: Optional[int] = None) -> torch.Tensor:
        """Return a lower-triangular matrix of ones, optionally repeated per batch item.

        Args:
            block_size: Side length of the square matrix.
            batch_size: When given, the matrix is repeated along a new
                leading dimension of this size.

        Returns:
            A tensor of shape ``(block_size, block_size)``, or
            ``(batch_size, block_size, block_size)`` when ``batch_size`` is
            given. With ``batch_size=None`` the returned tensor is the cached
            one, so treat it as read-only.
        """
        matrix = self._tril_base(block_size)
        if batch_size is not None:
            matrix = matrix.repeat(batch_size, 1, 1)
        return matrix

    def register_core_module(self, module: torch.nn.Module) -> torch.nn.Module:
        """Remember ``module`` under ``__core_module__`` and return it unchanged.

        ``object.__setattr__`` is used so the reference is stored as a plain
        attribute, bypassing the attribute handling of the module base class.
        """
        object.__setattr__(self, "__core_module__", module)
        return module

    def training_step(self, batch: Dict[str, torch.Tensor], batch_idx):
        """Run one training forward pass, log ``train_loss`` and return the loss."""
        if self.use_tril_attention_mask:
            batch_size, block_size = batch["input_ids"].shape
            tril_mask = self.get_batch_tril_matrix(block_size, batch_size=batch_size)
            # Shallow copy: add the mask without mutating the caller's batch.
            batch = {**batch, "attention_mask": tril_mask.to(self.device)}

        outputs, loss = self.llm(**batch)
        self.log("train_loss", loss, rank_zero_only=True)
        return loss

    def validation_step(self, batch: Dict[str, torch.Tensor], batch_idx):
        """Accumulate validation loss and next-token accuracy for one batch.

        Logits at position ``t`` are scored against the label at ``t + 1``:
        the last logit position and the first label position are dropped
        before both are flattened.
        """
        outputs, loss = self.llm(**batch, return_dict=True)

        logits = outputs[..., :-1, :]
        logits = logits.contiguous().view(-1, logits.size(-1))
        labels = batch["labels"][..., 1:].contiguous().view(-1)

        # Keep only labels that are valid class indices. Labels at or above
        # vocab_size were already excluded; negative labels (presumably an
        # ignore index such as -100 — verify against the dataset) are
        # excluded too, since the multiclass metric cannot score them.
        label_mask = (labels >= 0) & (labels < self.vocab_size)

        self.metric_accuracy.update(logits[label_mask], labels[label_mask])
        self.metric_loss.update(loss)

    def on_validation_epoch_end(self) -> None:
        """Log the accumulated validation loss and accuracy metrics."""
        # NOTE(review): Lightning documents that values logged with
        # rank_zero_only=True cannot be used as a callback monitor; revisit
        # before enabling checkpointing or early stopping on "accuracy".
        self.log("val_loss", self.metric_loss, rank_zero_only=True)
        self.log("accuracy", self.metric_accuracy, rank_zero_only=True)

    def configure_optimizers(self):
        """Return an AdamW optimizer over the trainer's model parameters."""
        return torch.optim.AdamW(self.trainer.model.parameters(), lr=self.learning_rate)

    def configure_callbacks(self):
        """Return the callbacks this module installs: a per-step learning-rate monitor."""
        lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval="step")
        return [lr_monitor]