from functools import cache
from typing import Dict, Optional

import pytorch_lightning as pl
import torch
import torchmetrics

from configuration import ModelConfig, TrainConfig
from model.modeling_wit import QWenLMHeadModel


class QwenModule(pl.LightningModule):
    def __init__(self, conf: Optional[TrainConfig] = None):
        pretrained_model_dir = conf.pretrain_model_name
        learning_rate = conf.learning_rate
        mconf = conf.model_config
        use_tril_attention_mask = conf.use_tril_attention_mask

        super().__init__()
        self.save_hyperparameters()
        if mconf is None:
            mconf = ModelConfig()
        model = QWenLMHeadModel(mconf)
        if pretrained_model_dir is not None:
            from modelscope import snapshot_download

            model = model.from_pretrained(snapshot_download(pretrained_model_dir))
        self.llm = self.register_core_module(model)
        self.learning_rate = learning_rate
        self.use_tril_attention_mask = use_tril_attention_mask
        self.metric_loss = torchmetrics.MeanMetric()
        self.vocab_size = self.llm.config.vocab_size
        self.metric_accuracy = torchmetrics.Accuracy(
            task="multiclass",
            num_classes=self.vocab_size,
        )

    @cache
    def get_batch_tril_matrix(self, block_size: int, batch_size: Optional[int] = None) -> torch.Tensor:
        # Lower-triangular (causal) attention mask, optionally expanded to the batch dimension.
        matrix = torch.ones(block_size, block_size).tril()
        if batch_size is not None:
            matrix = matrix.repeat(batch_size, 1, 1)
        return matrix

    def register_core_module(self, module: torch.nn.Module) -> torch.nn.Module:
        # Bypass nn.Module.__setattr__ so the model is not registered a second time as a submodule.
        object.__setattr__(self, "__core_module__", module)
        return module

    def training_step(self, batch: Dict[str, torch.Tensor], batch_idx):
        batch_size, block_size = batch["input_ids"].shape
        if self.use_tril_attention_mask:
            batch["attention_mask"] = self.get_batch_tril_matrix(block_size, batch_size=batch_size).to(self.device)
        outputs, loss = self.llm(**batch)
        self.log("train_loss", loss, rank_zero_only=True)
        return loss

    def validation_step(self, batch: Dict[str, torch.Tensor], batch_idx):
        outputs, loss = self.llm(**batch, return_dict=True)

        # Shift logits and labels so each position is scored against the next token.
        logits = outputs[..., :-1, :]
        logits = logits.contiguous().view(-1, logits.size(-1))
        labels = batch["labels"][..., 1:]
        labels = labels.contiguous().view(-1)
        if "val_mask" in batch and batch["val_mask"] is not None:
            label_mask = batch["val_mask"][..., 1:]
            label_mask = label_mask.contiguous().view(-1)
            logits = logits[label_mask]
            labels = labels[label_mask]
        if logits.numel() != 0 and labels.numel() != 0:
            self.metric_accuracy.update(logits, labels)
        self.metric_loss.update(loss)

    def on_validation_epoch_end(self) -> None:
        self.log("val_loss", self.metric_loss, rank_zero_only=True)
        self.log("accuracy", self.metric_accuracy, rank_zero_only=True)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.trainer.model.parameters(), lr=self.learning_rate)
        return optimizer

    def configure_callbacks(self):
        checkpoint_callback = pl.callbacks.ModelCheckpoint(
            monitor="accuracy",
            mode="max",
            filename="{epoch:02d}-{accuracy:.4f}",
        )
        early_stop_callback = pl.callbacks.EarlyStopping(
            monitor="accuracy",
            min_delta=0.001,
            patience=3,
            mode="max",
            stopping_threshold=1,
        )
        lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval="step")
        return [lr_monitor]
        # return [checkpoint_callback, lr_monitor]
        # return [checkpoint_callback, early_stop_callback]
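

# ----------------------------------------------------------------------------
# Usage sketch. This is a hypothetical, minimal example of driving QwenModule
# with pytorch_lightning.Trainer; it is not the repository's real entry point.
# Assumptions not confirmed by this file: TrainConfig is default-constructible
# and its fields can be assigned directly, and the model accepts batches with
# "input_ids" and "labels" keys as used in training_step above.
# ----------------------------------------------------------------------------
if __name__ == "__main__":
    from torch.utils.data import DataLoader, Dataset

    class RandomTokenDataset(Dataset):
        """Toy dataset yielding random token sequences, just to exercise the module."""

        def __init__(self, vocab_size: int, block_size: int, length: int = 64):
            self.vocab_size = vocab_size
            self.block_size = block_size
            self.length = length

        def __len__(self):
            return self.length

        def __getitem__(self, idx):
            ids = torch.randint(0, self.vocab_size, (self.block_size,))
            return {"input_ids": ids, "labels": ids.clone()}

    conf = TrainConfig()                 # assumed default-constructible
    conf.pretrain_model_name = None      # train from scratch in this sketch
    conf.learning_rate = 1e-4
    conf.model_config = None             # fall back to ModelConfig() defaults
    conf.use_tril_attention_mask = True

    module = QwenModule(conf)
    loader = DataLoader(RandomTokenDataset(module.vocab_size, block_size=32), batch_size=4)

    trainer = pl.Trainer(max_epochs=1, limit_train_batches=4, limit_val_batches=2)
    trainer.fit(module, train_dataloaders=loader, val_dataloaders=loader)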