import copy
import gc
import json
import os
from functools import cache, lru_cache
from typing import Dict, Optional, Union

import pytorch_lightning as pl
import torch
import torchmetrics
from safetensors.torch import load_file as safe_load_file
from safetensors.torch import save_file as safe_save_file
from torch import nn
from tqdm import auto as tqdm_lib

from configuration import ModelConfig, TrainConfig
from model.modeling_wit import QWenLMHeadModel

# NOTE(review): `make_context` and `decode_tokens` are used by ModelRunner but
# are not imported anywhere in the visible source - confirm where they come
# from and import them here.


class LoadModule:
    """Helpers that load a sharded safetensors checkpoint into an existing module.

    All functions are called through the class (``LoadModule.fn(...)``) and take
    no ``self``; the parameter named ``cls`` is the *model instance* that
    receives the weights, not a class object.
    """

    @staticmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]]):
        """Load every shard listed in ``model.safetensors.index.json`` into ``cls``.

        Args:
            cls: The already-constructed module to populate.
            pretrained_model_name_or_path: Directory containing the index file
                and the shard files it references.

        Returns:
            The same ``cls`` object, after loading.
        """
        pretrained_model_name_or_path = str(pretrained_model_name_or_path)
        index_file = os.path.join(pretrained_model_name_or_path, "model.safetensors.index.json")
        print(f"loading weights file {index_file}")
        with open(index_file, "r", encoding="utf-8") as f:
            index = json.load(f)
        # Several weights map to the same shard; load each shard file once.
        shard_filenames = sorted(set(index["weight_map"].values()))
        shard_paths = [os.path.join(pretrained_model_name_or_path, name) for name in shard_filenames]
        return LoadModule._load_pretrained_model(cls, shard_paths)

    @staticmethod
    def _load_state_dict_into_model(model_to_load, state_dict, start_prefix):
        """Recursively copy ``state_dict`` entries into ``model_to_load``.

        Returns:
            The list of error messages filled in by ``_load_from_state_dict``.
        """
        # Work on a shallow copy (keeping any `_metadata`) so the caller's
        # mapping is not the object handed to the modules.
        metadata = getattr(state_dict, "_metadata", None)
        state_dict = state_dict.copy()
        if metadata is not None:
            state_dict._metadata = metadata

        error_msgs = []

        def load(module: nn.Module, state_dict, prefix=""):
            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
            # Missing/unexpected key lists are throwaway: a single shard only
            # holds part of the model, so gaps are expected per call.
            args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
            # Only touch modules that have at least one key in this shard.
            if any(key.startswith(prefix) for key in state_dict):
                module._load_from_state_dict(*args)
            for name, child in module._modules.items():
                if child is not None:
                    load(child, state_dict, prefix + name + ".")

        load(model_to_load, state_dict, prefix=start_prefix)
        return error_msgs

    @staticmethod
    def _load_pretrained_model(cls, resolved_archive_file):
        """Load each shard file in ``resolved_archive_file`` into ``cls``.

        Shards are loaded one at a time and released before the next one is
        read. Error messages reported by the loader are printed instead of
        the success message; they do not raise.

        Returns:
            ``cls``.
        """
        start_prefix = ""
        model_to_load = cls
        if len(resolved_archive_file) > 1:
            resolved_archive_file = tqdm_lib.tqdm(resolved_archive_file, desc="Loading checkpoint shards")

        error_msgs = []
        for shard_file in resolved_archive_file:
            state_dict = safe_load_file(shard_file)
            error_msgs += LoadModule._load_state_dict_into_model(model_to_load, state_dict, start_prefix)
            del state_dict  # force memory release
            gc.collect()

        if error_msgs:
            joined = "\n\t".join(error_msgs)
            print(f"Some weights could not be loaded into {cls.__class__.__name__}:\n\t{joined}\n")
        else:
            print(f"All model checkpoint weights were used when initializing {cls.__class__.__name__}.\n")
        return cls


class ModelRunner:
    """Sampling-based text generation around a wrapped ``qwen`` model.

    The wrapped model's ``forward`` is expected to return ``(outputs, loss)``
    and its ``config`` must provide ``repetition_penalty`` and ``top_p``.
    """

    def __init__(self, qwen):
        self.qwen = qwen

    @torch.no_grad()
    def ChatTokens(self, input_ids, sample=True):
        """Score the next token for ``input_ids``.

        Returns:
            A sampled token tensor when ``sample`` is true, otherwise the
            ``torch.sort`` result (values, indices) of the penalised scores in
            descending order.
        """
        qwen = self.qwen
        input_ids = input_ids.to(next(qwen.parameters()).device)
        outputs, loss = qwen.forward(input_ids)
        # presumably outputs is (batch, seq, vocab) - keep the last position only
        next_token_scores = outputs[:, -1, :]
        next_token_scores = self.repetition_penalty(input_ids, next_token_scores)
        if sample:
            next_token_scores = self.top_p(next_token_scores)
            return self.sample(next_token_scores)
        return torch.sort(next_token_scores, descending=True)

    @torch.no_grad()
    def Chat(
        self,
        tokenizer,
        query: str,
        query_assistant: str,
        gen_length=0,
        system: str = "You are a helpful assistant.",
        history=None,
    ):
        """Generate a reply to ``query`` by sampling one token at a time.

        Args:
            tokenizer: Tokenizer exposing ``eod_id``; also passed to the
                context-building and decoding helpers.
            query: User message.
            query_assistant: Forwarded to ``make_context``.
            gen_length: When non-zero, generation stops once more than
                ``gen_length`` tokens have been appended to the prompt.
            system: System prompt.
            history: Earlier ``(query, response)`` pairs; never mutated.

        Returns:
            ``(token_ids, history, decoded)`` where ``history`` is a copy of
            the input history with the new ``(query, response)`` appended.
        """
        qwen = self.qwen
        history = [] if history is None else copy.deepcopy(history)

        qwen.config.pad_token_id = tokenizer.eod_id
        qwen.config.eos_token_id = tokenizer.eod_id

        raw_text, context_tokens = self.prepareInput(tokenizer, query, query_assistant, history, system)
        input_ids = torch.tensor([context_tokens]).to(next(qwen.parameters()).device)
        # 1 while a sequence is still generating, 0 once it has emitted EOS.
        self.unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
        input_length = input_ids.shape[1]

        while True:
            outputs, loss = qwen.forward(input_ids)
            next_token_scores = outputs[:, -1, :]
            next_token_scores = self.repetition_penalty(input_ids, next_token_scores)
            next_token_scores = self.top_p(next_token_scores)
            next_tokens = self.sample(next_token_scores)
            finish, next_tokens = self.isFinish(next_tokens)
            if finish:
                break
            input_ids = torch.cat([input_ids, next_tokens], dim=-1)
            if gen_length != 0 and (input_length + gen_length) < input_ids.shape[1]:
                break

        decoded, response, end_reason = decode_tokens(
            input_ids[0],
            tokenizer,
            raw_text_len=len(raw_text),
            context_length=len(context_tokens),
            errors="replace",
        )
        history.append((query, response))
        return input_ids[0].cpu().tolist(), history, decoded

    def prepareInput(self, tokenizer, query, query_assistant, history, system):
        """Build the prompt via ``make_context``; returns its result unchanged."""
        return make_context(tokenizer, query, query_assistant, history=history, system=system)

    def repetition_penalty(self, input_ids, next_token_scores):
        """Penalise the scores of tokens that already occur in ``input_ids``.

        ``next_token_scores`` is modified in place and returned.
        """
        penalty = self.qwen.config.repetition_penalty
        score = torch.gather(next_token_scores, 1, input_ids)
        # Negative scores are multiplied and positive ones divided, so the
        # token becomes less likely in both cases.
        score = torch.where(score < 0, score * penalty, score / penalty)
        next_token_scores = next_token_scores.scatter_(1, input_ids, score)
        return next_token_scores

    def top_p(self, next_token_scores):
        """Mask out the low-probability tail with ``-inf`` (nucleus filtering).

        Tokens whose ascending cumulative probability is at most
        ``1 - top_p`` are removed; the highest-scoring token is always kept.
        """
        top_p = self.qwen.config.top_p
        filter_value = -float("Inf")
        min_tokens_to_keep = 1

        sorted_logits, sorted_indices = torch.sort(next_token_scores, descending=False)
        cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
        # Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
        sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
        # Keep at least min_tokens_to_keep
        sorted_indices_to_remove[..., -min_tokens_to_keep:] = 0
        # scatter sorted tensors to original indexing
        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
        return next_token_scores.masked_fill(indices_to_remove, filter_value)

    def sample(self, next_token_scores):
        """Draw one token per row from the softmax of ``next_token_scores``."""
        probs = nn.functional.softmax(next_token_scores, dim=-1)
        return torch.multinomial(probs, num_samples=1).squeeze(1)

    def isFinish(self, next_tokens):
        """Update the per-sequence finished flags for the sampled tokens.

        Sequences that already finished get the pad token instead of the
        sampled one; a sequence is marked finished when its token equals EOS.

        Returns:
            ``(all_finished, next_tokens)`` where ``next_tokens`` has a
            trailing dimension added so it can be concatenated to the input.
        """
        pad_token_id = self.qwen.config.pad_token_id
        eos_token_id_tensor = torch.tensor([self.qwen.config.eos_token_id]).to(next_tokens.device)
        next_tokens = next_tokens * self.unfinished_sequences + pad_token_id * (1 - self.unfinished_sequences)
        self.unfinished_sequences = self.unfinished_sequences.mul(
            next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
        )
        return self.unfinished_sequences.max() == 0, next_tokens[:, None]


@lru_cache(maxsize=16)
def _build_tril_matrix(block_size: int, batch_size: Optional[int]) -> torch.Tensor:
    """Return a cached lower-triangular matrix of ones, optionally batched."""
    matrix = torch.ones(block_size, block_size).tril()
    if batch_size is not None:
        matrix = matrix.repeat(batch_size, 1, 1)
    return matrix


class QwenModule(pl.LightningModule):
    """Lightning wrapper that trains and validates a ``QWenLMHeadModel``."""

    def __init__(self, conf: TrainConfig = None):
        """Build the model described by ``conf`` and optionally load weights.

        Despite the ``None`` default, ``conf`` is required: its attributes are
        read unconditionally. When ``conf.pretrain_model_name`` is set, the
        checkpoint is fetched with ``modelscope.snapshot_download`` and loaded
        into the freshly built model.
        """
        super().__init__()
        self.config = conf
        pretrained_model_dir = conf.pretrain_model_name
        learning_rate = conf.learning_rate
        mconf = conf.model_config
        use_tril_attention_mask = conf.use_tril_attention_mask
        self.save_hyperparameters()

        if mconf is None:
            mconf = ModelConfig()
        model = QWenLMHeadModel(mconf)
        if pretrained_model_dir is not None:
            from modelscope import snapshot_download

            # The loader needs the target module as its first argument.
            model = LoadModule.from_pretrained(model, snapshot_download(pretrained_model_dir))

        self.llm = self.register_core_module(model)
        self.learning_rate = learning_rate
        self.use_tril_attention_mask = use_tril_attention_mask
        self.metric_loss = torchmetrics.MeanMetric()
        self.vocab_size = self.llm.config.vocab_size
        self.metric_accuracy = torchmetrics.Accuracy(
            task="multiclass",
            num_classes=self.vocab_size,
        )

    def get_batch_tril_matrix(self, block_size: int, batch_size: Optional[int] = None) -> torch.Tensor:
        """Return a ``block_size`` x ``block_size`` lower-triangular matrix of ones.

        When ``batch_size`` is given the matrix is repeated along a new
        leading dimension. Results come from a bounded module-level cache, so
        callers must not modify the returned tensor in place.
        """
        return _build_tril_matrix(block_size, batch_size)

    def register_core_module(self, module: torch.nn.Module) -> torch.nn.Module:
        """Store ``module`` under ``__core_module__`` and return it.

        ``object.__setattr__`` is used so the assignment bypasses
        ``nn.Module.__setattr__``.
        """
        object.__setattr__(self, "__core_module__", module)
        return module

    def training_step(self, batch: Dict[str, torch.Tensor], batch_idx):
        """Run the model on ``batch``, log ``train_loss`` and return the loss."""
        batch_size, block_size = batch["input_ids"].shape
        if self.use_tril_attention_mask:
            batch["attention_mask"] = self.get_batch_tril_matrix(block_size, batch_size=batch_size).to(self.device)
        outputs, loss = self.llm(**batch)
        self.log("train_loss", loss, rank_zero_only=True)
        return loss

    def validation_step(self, batch: Dict[str, torch.Tensor], batch_idx):
        """Accumulate next-token accuracy and loss for one validation batch."""
        outputs, loss = self.llm(**batch, return_dict=True)

        # Position i predicts token i + 1: drop the last logit and first label.
        logits = outputs[..., :-1, :]
        logits = logits.contiguous().view(-1, logits.size(-1))
        labels = batch["labels"][..., 1:]
        labels = labels.contiguous().view(-1)

        val_mask = batch.get("val_mask")
        if val_mask is not None:
            label_mask = val_mask[..., 1:]
            label_mask = label_mask.contiguous().view(-1)
            logits = logits[label_mask]
            labels = labels[label_mask]

        # The mask may select nothing; skip the accuracy update in that case.
        if logits.numel() != 0 and labels.numel() != 0:
            self.metric_accuracy.update(logits, labels)
        self.metric_loss.update(loss)

    def on_validation_epoch_end(self) -> None:
        """Log the accumulated validation loss and accuracy metrics."""
        self.log("val_loss", self.metric_loss, rank_zero_only=True)
        self.log("accuracy", self.metric_accuracy, rank_zero_only=True)
        self.log("hp_metric", self.metric_accuracy, rank_zero_only=True)

    def configure_optimizers(self):
        """Return an AdamW optimizer over the trainer's model parameters."""
        optimizer = torch.optim.AdamW(self.trainer.model.parameters(), lr=self.learning_rate)
        return optimizer

    def configure_callbacks(self):
        """Return the active callbacks.

        Only the learning-rate monitor is returned; the checkpoint and
        early-stopping callbacks are built but currently not enabled.
        """
        checkpoint_callback = pl.callbacks.ModelCheckpoint(
            monitor="accuracy",
            mode="max",
            filename="{epoch:02d}-{accuracy:.4f}",
        )
        early_stop_callback = pl.callbacks.EarlyStopping(
            monitor="accuracy",
            min_delta=0.001,
            patience=3,
            mode="max",
            stopping_threshold=1,
        )
        lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval="step")
        return [lr_monitor]
        # return [checkpoint_callback, lr_monitor]
        # return [checkpoint_callback, early_stop_callback]