# Witllm/wit/model/qwen_module.py
#
# 252 lines
# 11 KiB
# Python

import copy
import gc
import json
import os
from functools import cache
from typing import Dict, Optional, Union

import pytorch_lightning as pl
import torch
import torchmetrics
from safetensors.torch import load_file as safe_load_file
from safetensors.torch import save_file as safe_save_file
from torch import nn
from tqdm import auto as tqdm_lib

from configuration import ModelConfig, TrainConfig
class LoadModule:
    """Load sharded safetensors checkpoints into an already-built model.

    Every helper is a static method, so it can be called either through the
    class (``LoadModule.from_pretrained(model, path)``) or through an
    instance. None of them reads any state from ``LoadModule`` itself.
    """

    @staticmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]]):
        """Load every checkpoint shard listed in a model directory into ``cls``.

        Args:
            cls: The module *instance* that receives the weights. The name is
                kept for backward compatibility; it is not a class object
                (its ``__class__.__name__`` is what gets reported).
            pretrained_model_name_or_path: Directory that contains
                ``model.safetensors.index.json`` and the shard files that
                index refers to.

        Returns:
            The same object that was passed as ``cls``, with weights loaded.

        Raises:
            RuntimeError: If any shard reports an error while being loaded.
        """
        model_dir = str(pretrained_model_name_or_path)
        index_file = os.path.join(model_dir, "model.safetensors.index.json")
        print(f"loading weights file {index_file}")
        with open(index_file, "r", encoding="utf-8") as f:
            index = json.load(f)
        # Several tensors map to the same shard file: de-duplicate, and sort so
        # the loading order is deterministic.
        shard_filenames = sorted(set(index["weight_map"].values()))
        shard_paths = [os.path.join(model_dir, name) for name in shard_filenames]
        return LoadModule._load_pretrained_model(cls, shard_paths)

    @staticmethod
    def _load_state_dict_into_model(model_to_load, state_dict, start_prefix):
        """Copy the tensors of ``state_dict`` into ``model_to_load``.

        The module tree is walked recursively and ``_load_from_state_dict`` is
        invoked only on modules for which the state dict holds at least one
        key under the module's prefix, so a partial (single-shard) state dict
        can be applied without touching the rest of the model.

        Args:
            model_to_load: Root ``nn.Module`` that receives the weights.
            state_dict: Mapping from parameter name to tensor. It is
                shallow-copied before use.
            start_prefix: Key prefix that corresponds to ``model_to_load``.

        Returns:
            The list of error messages collected by ``_load_from_state_dict``;
            empty when nothing went wrong.
        """
        metadata = getattr(state_dict, "_metadata", None)
        state_dict = state_dict.copy()
        if metadata is not None:
            # ``copy()`` does not carry the attribute over, so restore it.
            state_dict._metadata = metadata

        error_msgs = []

        def load(module: nn.Module, state_dict, prefix=""):
            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
            if any(key.startswith(prefix) for key in state_dict):
                # Fresh lists for missing/unexpected keys: only ``error_msgs``
                # is accumulated across modules.
                module._load_from_state_dict(state_dict, prefix, local_metadata, True, [], [], error_msgs)
            for name, child in module._modules.items():
                if child is not None:
                    load(child, state_dict, prefix + name + ".")

        load(model_to_load, state_dict, prefix=start_prefix)
        return error_msgs

    @staticmethod
    def _load_pretrained_model(cls, resolved_archive_file):
        """Load a list of safetensors shard files into ``cls`` one at a time.

        Args:
            cls: The module instance that receives the weights.
            resolved_archive_file: List of shard file paths.

        Returns:
            ``cls`` itself, after all shards have been applied.

        Raises:
            RuntimeError: If loading any shard produced error messages.
        """
        start_prefix = ""
        model_to_load = cls
        shards = resolved_archive_file
        if len(shards) > 1:
            shards = tqdm_lib.tqdm(shards, desc="Loading checkpoint shards")

        error_msgs = []
        for shard_file in shards:
            state_dict = safe_load_file(shard_file)
            error_msgs.extend(LoadModule._load_state_dict_into_model(model_to_load, state_dict, start_prefix))
            # Drop the shard before reading the next one to keep peak memory low.
            del state_dict
            gc.collect()

        if error_msgs:
            # Previously these messages were discarded and the success line
            # below was printed regardless.
            joined = "\n\t".join(error_msgs)
            raise RuntimeError(f"Error(s) in loading state_dict for {cls.__class__.__name__}:\n\t{joined}")

        print(f"All model checkpoint weights were used when initializing {cls.__class__.__name__}.\n")
        return cls
class ModelRunner:
    """Drive token-by-token generation on top of a wrapped language model.

    The wrapped model is expected to expose ``parameters()``, a ``config``
    object (``repetition_penalty``, ``top_p``, ``pad_token_id``,
    ``eos_token_id``) and a ``forward(input_ids)`` that returns a
    ``(logits, loss)`` pair.
    """

    def __init__(self, qwen):
        """Store the model that all generation helpers operate on."""
        self.qwen = qwen

    @torch.no_grad()
    def ChatTokens(self, input_ids, sample=True):
        """Score the next token for ``input_ids``.

        Args:
            input_ids: Token id tensor; moved to the model's device.
            sample: When true, apply top-p filtering and draw one token per
                row. When false, return all scores sorted in descending order.

        Returns:
            The sampled token ids if ``sample`` is true, otherwise the
            ``(values, indices)`` result of ``torch.sort``.
        """
        qwen = self.qwen
        input_ids = input_ids.to(next(qwen.parameters()).device)
        outputs, _ = qwen.forward(input_ids)
        # Only the scores at the last position predict the next token.
        next_token_scores = outputs[:, -1, :]
        next_token_scores = self.repetition_penalty(input_ids, next_token_scores)
        if not sample:
            return torch.sort(next_token_scores, descending=True)
        next_token_scores = self.top_p(next_token_scores)
        return self.sample(next_token_scores)

    @torch.no_grad()
    def Chat(
        self,
        tokenizer,
        query: str,
        query_assistant: str,
        gen_length=0,
        system: str = "You are a helpful assistant.",
        history=None,
    ):
        """Generate a reply to ``query`` and return it with the updated history.

        Args:
            tokenizer: Tokenizer exposing ``eod_id``; also forwarded to the
                prompt-building and decoding helpers.
            query: The user message.
            query_assistant: Forwarded unchanged to ``prepareInput``.
            gen_length: Maximum number of tokens to generate; ``0`` means
                generate until the end-of-sequence token is sampled.
            system: System prompt.
            history: Previous ``(query, response)`` pairs. The caller's list
                is never modified; ``None`` means an empty history.

        Returns:
            A tuple ``(token_ids, history, decoded)``: the full token id list
            of the first sequence, the new history list, and the first value
            returned by ``decode_tokens``.
        """
        qwen = self.qwen
        history = [] if history is None else copy.deepcopy(history)
        # The tokenizer's end-of-document id doubles as pad and EOS id.
        qwen.config.pad_token_id = tokenizer.eod_id
        qwen.config.eos_token_id = tokenizer.eod_id

        # ``prepareInput`` is defined on this class, not assumed on the model.
        raw_text, context_tokens = self.prepareInput(tokenizer, query, query_assistant, history, system)
        input_ids = torch.tensor([context_tokens]).to(next(qwen.parameters()).device)
        self.unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
        input_length = input_ids.shape[1]

        # NOTE(review): with gen_length == 0 the loop only ends once EOS is sampled.
        while True:
            # Run the wrapped model; this class has no ``forward`` of its own.
            outputs, _ = qwen.forward(input_ids)
            next_token_scores = outputs[:, -1, :]
            next_token_scores = self.repetition_penalty(input_ids, next_token_scores)
            next_token_scores = self.top_p(next_token_scores)
            next_tokens = self.sample(next_token_scores)
            finish, next_tokens = self.isFinish(next_tokens)
            if finish:
                break
            input_ids = torch.cat([input_ids, next_tokens], dim=-1)
            # Stop once exactly ``gen_length`` new tokens have been appended.
            if gen_length != 0 and input_ids.shape[1] - input_length >= gen_length:
                break

        # NOTE(review): ``decode_tokens`` is not defined or imported in the
        # visible part of this file — confirm where it comes from.
        decoded, response, end_reason = decode_tokens(
            input_ids[0],
            tokenizer,
            raw_text_len=len(raw_text),
            context_length=len(context_tokens),
            errors="replace",
        )
        history.append((query, response))
        return input_ids[0].cpu().tolist(), history, decoded

    def prepareInput(self, tokenizer, query, query_assistant, history, system):
        """Build the prompt for one chat turn by delegating to ``make_context``.

        Returns whatever ``make_context`` returns; ``Chat`` unpacks it as
        ``(raw_text, context_tokens)``.
        """
        # NOTE(review): ``make_context`` is not defined or imported in the
        # visible part of this file — confirm where it comes from.
        return make_context(tokenizer, query, query_assistant, history=history, system=system)

    def repetition_penalty(self, input_ids, next_token_scores):
        """Penalise the scores of tokens that already occur in ``input_ids``.

        ``next_token_scores`` is modified in place and also returned.
        """
        penalty = self.qwen.config.repetition_penalty
        seen_scores = torch.gather(next_token_scores, 1, input_ids)
        # Negative scores are multiplied and positive ones divided, so both
        # move towards "less likely".
        seen_scores = torch.where(seen_scores < 0, seen_scores * penalty, seen_scores / penalty)
        return next_token_scores.scatter_(1, input_ids, seen_scores)

    def top_p(self, next_token_scores):
        """Apply nucleus (top-p) filtering to a batch of score rows.

        Tokens outside the nucleus get a score of ``-inf``; at least one token
        per row is always kept.
        """
        top_p = self.qwen.config.top_p
        filter_value = -float("Inf")
        min_tokens_to_keep = 1

        # Ascending sort: the low-probability tail comes first.
        sorted_logits, sorted_indices = torch.sort(next_token_scores, descending=False)
        cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
        # Drop the tail whose cumulative probability stays within 1 - top_p.
        sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
        # Never drop the most probable token(s).
        sorted_indices_to_remove[..., -min_tokens_to_keep:] = 0
        # Map the mask from sorted order back to vocabulary order.
        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
        return next_token_scores.masked_fill(indices_to_remove, filter_value)

    def sample(self, next_token_scores):
        """Draw one token id per row from the softmax of ``next_token_scores``."""
        probs = nn.functional.softmax(next_token_scores, dim=-1)
        return torch.multinomial(probs, num_samples=1).squeeze(1)

    def isFinish(self, next_tokens):
        """Update the per-sequence finished flags with the newly sampled tokens.

        Reads and updates ``self.unfinished_sequences``, which ``Chat``
        initialises. Sequences that already finished get the pad token in
        place of the sampled one.

        Returns:
            A pair ``(finished, tokens)``: ``finished`` is true once every
            sequence has produced the EOS token, and ``tokens`` is the
            (possibly padded) ``next_tokens`` with a trailing axis added.
        """
        pad_token_id = self.qwen.config.pad_token_id
        eos_token_id_tensor = torch.tensor([self.qwen.config.eos_token_id]).to(next_tokens.device)
        next_tokens = next_tokens * self.unfinished_sequences + pad_token_id * (1 - self.unfinished_sequences)
        # A sequence stays unfinished only while its token differs from EOS.
        still_running = (
            next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
        )
        self.unfinished_sequences = self.unfinished_sequences.mul(still_running)
        return self.unfinished_sequences.max() == 0, next_tokens[:, None]
class QwenModule(pl.LightningModule):
    """Lightning wrapper that trains and validates a language model.

    The wrapped model is called as ``self.llm(**batch)`` and is expected to
    return an ``(outputs, loss)`` pair and to expose ``config.vocab_size``.
    """

    def __init__(self, conf: TrainConfig, model):
        """Set up the module from a training configuration.

        Args:
            conf: Training configuration; ``pretrain_model_name``,
                ``learning_rate``, ``model_config`` and
                ``use_tril_attention_mask`` are read from it.
            model: The model to train. When ``conf.pretrain_model_name`` is
                set, the pretrained weights are downloaded and loaded into it.
        """
        super().__init__()
        # NOTE(review): this also records ``model`` as a hyperparameter;
        # consider ``ignore=["model"]`` if checkpoints do not need it.
        self.save_hyperparameters()
        self.config = conf

        pretrained_model_dir = conf.pretrain_model_name
        mconf = conf.model_config
        if mconf is None:
            # NOTE(review): ``mconf`` is not used after this point.
            mconf = ModelConfig()

        if pretrained_model_dir is not None:
            from modelscope import snapshot_download

            # ``from_pretrained`` needs the target model as well as the directory.
            model = LoadModule.from_pretrained(model, snapshot_download(pretrained_model_dir))

        self.llm = self.register_core_module(model)
        self.learning_rate = conf.learning_rate
        self.use_tril_attention_mask = conf.use_tril_attention_mask
        self.metric_loss = torchmetrics.MeanMetric()
        self.vocab_size = self.llm.config.vocab_size
        self.metric_accuracy = torchmetrics.Accuracy(
            task="multiclass",
            num_classes=self.vocab_size,
        )

    @staticmethod
    @cache
    def get_batch_tril_matrix(block_size: int, batch_size: Optional[int] = None) -> torch.Tensor:
        """Return a lower-triangular matrix of ones, optionally repeated per batch.

        The result is cached per ``(block_size, batch_size)``. The method is
        static so the cache does not hold a reference to any module instance.

        Args:
            block_size: Side length of the square matrix.
            batch_size: When given, the matrix is repeated this many times
                along a new leading axis.

        Returns:
            A ``(block_size, block_size)`` tensor, or
            ``(batch_size, block_size, block_size)`` when ``batch_size`` is set.
        """
        matrix = torch.ones(block_size, block_size).tril()
        if batch_size is not None:
            matrix = matrix.repeat(batch_size, 1, 1)
        return matrix

    def register_core_module(self, module: torch.nn.Module) -> torch.nn.Module:
        """Store ``module`` under ``__core_module__`` and return it unchanged.

        ``object.__setattr__`` is used, which bypasses the module's own
        ``__setattr__`` for this attribute.
        """
        object.__setattr__(self, "__core_module__", module)
        return module

    def training_step(self, batch: Dict[str, torch.Tensor], batch_idx):
        """Run one training step and return the loss reported by the model."""
        batch_size, block_size = batch["input_ids"].shape
        if self.use_tril_attention_mask:
            tril_mask = self.get_batch_tril_matrix(block_size, batch_size=batch_size)
            batch["attention_mask"] = tril_mask.to(self.device)
        outputs, loss = self.llm(**batch)
        self.log("train_loss", loss, rank_zero_only=True)
        return loss

    def validation_step(self, batch: Dict[str, torch.Tensor], batch_idx):
        """Accumulate next-token accuracy and loss for one validation batch.

        The logits at position ``i`` are compared with the label at position
        ``i + 1``. When the batch carries a ``val_mask``, only the positions
        it selects contribute to the accuracy.
        """
        outputs, loss = self.llm(**batch, return_dict=True)
        # Shift by one: drop the last prediction and the first label.
        logits = outputs[..., :-1, :]
        logits = logits.contiguous().view(-1, logits.size(-1))
        labels = batch["labels"][..., 1:].contiguous().view(-1)

        val_mask = batch.get("val_mask")
        if val_mask is not None:
            label_mask = val_mask[..., 1:].contiguous().view(-1)
            logits = logits[label_mask]
            labels = labels[label_mask]

        # The mask may leave nothing to score; skip the update in that case.
        if logits.numel() != 0 and labels.numel() != 0:
            self.metric_accuracy.update(logits, labels)
        self.metric_loss.update(loss)

    def on_validation_epoch_end(self) -> None:
        """Log the accumulated validation loss and accuracy."""
        self.log("val_loss", self.metric_loss, rank_zero_only=True)
        self.log("accuracy", self.metric_accuracy, rank_zero_only=True)
        self.log("hp_metric", self.metric_accuracy, rank_zero_only=True)

    def configure_optimizers(self):
        """Return an AdamW optimizer over the trainer's model parameters."""
        return torch.optim.AdamW(self.trainer.model.parameters(), lr=self.learning_rate)

    def configure_callbacks(self):
        """Return the callbacks to attach to the trainer.

        Only the learning-rate monitor is returned. The checkpoint and
        early-stopping callbacks are still constructed but are not enabled.
        """
        checkpoint_callback = pl.callbacks.ModelCheckpoint(
            monitor="accuracy",
            mode="max",
            filename="{epoch:02d}-{accuracy:.4f}",
        )
        early_stop_callback = pl.callbacks.EarlyStopping(
            monitor="accuracy",
            min_delta=0.001,
            patience=3,
            mode="max",
            stopping_threshold=1,
        )
        lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval="step")
        # To enable the other callbacks, add ``checkpoint_callback`` and/or
        # ``early_stop_callback`` to the returned list.
        return [lr_monitor]