import copy  # used by ModelRunner.Chat to deep-copy the conversation history
import os
import gc
import json
from tqdm import auto as tqdm_lib
from torch import nn
from safetensors.torch import load_file as safe_load_file
from safetensors.torch import save_file as safe_save_file

from functools import cache
from typing import Dict, Optional, Union

import pytorch_lightning as pl
import torch
import torchmetrics

from model.modeling_wit import QWenLMHeadModel
from configuration import ModelConfig, TrainConfig

# NOTE: `make_context` and `decode_tokens` are called below but are neither defined
# nor imported in this file; they are assumed to come from the project's
# generation/tokenizer utilities.

class LoadModule:
    # Loader for sharded safetensors checkpoints. In the helpers below, `cls` is the
    # model *instance* that receives the weights, not a class.
    def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]]):
        pretrained_model_name_or_path = str(pretrained_model_name_or_path)
        resolved_archive_file = os.path.join(pretrained_model_name_or_path, "model.safetensors.index.json")
        print(f"loading weights file {resolved_archive_file}")
        with open(resolved_archive_file, "r") as f:
            index = json.loads(f.read())
        # The index maps each tensor name to the shard file that stores it.
        shard_filenames = sorted(set(index["weight_map"].values()))
        resolved_archive_file = [os.path.join(pretrained_model_name_or_path, f) for f in shard_filenames]
        model = LoadModule._load_pretrained_model(cls, resolved_archive_file)
        return model

    def _load_state_dict_into_model(model_to_load, state_dict, start_prefix):
        # Shallow-copy the state dict so `_load_from_state_dict` can consume keys
        # without mutating the caller's dict.
        metadata = getattr(state_dict, "_metadata", None)
        state_dict = state_dict.copy()
        if metadata is not None:
            state_dict._metadata = metadata
        error_msgs = []

        def load(module: nn.Module, state_dict, prefix=""):
            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
            args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
            if len([key for key in state_dict if key.startswith(prefix)]) > 0:
                module._load_from_state_dict(*args)

            for name, child in module._modules.items():
                if child is not None:
                    load(child, state_dict, prefix + name + ".")

        load(model_to_load, state_dict, prefix=start_prefix)
        del state_dict
        return error_msgs

    def _load_pretrained_model(cls, resolved_archive_file):
        start_prefix = ""
        model_to_load = cls
        if len(resolved_archive_file) > 1:
            resolved_archive_file = tqdm_lib.tqdm(resolved_archive_file, desc="Loading checkpoint shards")
        for shard_file in resolved_archive_file:
            # Load one shard at a time to keep peak memory low.
            state_dict = safe_load_file(shard_file)
            LoadModule._load_state_dict_into_model(model_to_load, state_dict, start_prefix)
            del state_dict  # force memory release
            gc.collect()
        print(f"All model checkpoint weights were used when initializing {cls.__class__.__name__}.\n")
        return cls
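
# A minimal usage sketch for LoadModule (the checkpoint directory is hypothetical):
# weights from a sharded "model.safetensors" checkpoint are loaded into an already
# constructed model instance, which `from_pretrained` then returns.
#
#   model = QWenLMHeadModel(ModelConfig())
#   model = LoadModule.from_pretrained(model, "/path/to/qwen_checkpoint_dir")
#   runner = ModelRunner(model)
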
class ModelRunner:
    def __init__(self, qwen):
        self.qwen = qwen

    @torch.no_grad()
    def ChatTokens(self, input_ids, sample=True):
        # One decoding step for an already tokenized prompt: either sample a single
        # token id (sample=True) or return the scores sorted in descending order.
        qwen = self.qwen
        input_ids = input_ids.to(next(qwen.parameters()).device)
        outputs, loss = qwen.forward(input_ids)
        next_token_scores = outputs[:, -1, :]
        next_token_scores = self.repetition_penalty(input_ids, next_token_scores)
        if sample:
            next_token_scores = self.top_p(next_token_scores)
            return self.sample(next_token_scores)
        else:
            return torch.sort(next_token_scores, descending=True)

    @torch.no_grad()
    def Chat(
        self,
        tokenizer,
        query: str,
        query_assistant: str,
        gen_length=0,
        system: str = "You are a helpful assistant.",
        history=[],
    ):
        qwen = self.qwen
        history = copy.deepcopy(history)
        self.qwen.config.pad_token_id = tokenizer.eod_id
        self.qwen.config.eos_token_id = tokenizer.eod_id
        raw_text, context_tokens = self.prepareInput(tokenizer, query, query_assistant, history, system)
        input_ids = torch.tensor([context_tokens]).to(next(qwen.parameters()).device)
        self.unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
        input_length = input_ids.shape[1]
        while True:
            # Sampling loop: score the next token, apply the repetition penalty and
            # nucleus (top-p) filtering, then sample until EOS or the length budget.
            outputs, loss = qwen.forward(input_ids)
            next_token_scores = outputs[:, -1, :]

            next_token_scores = self.repetition_penalty(input_ids, next_token_scores)
            next_token_scores = self.top_p(next_token_scores)
            next_tokens = self.sample(next_token_scores)
            finish, next_tokens = self.isFinish(next_tokens)
            if finish:
                break
            input_ids = torch.cat([input_ids, next_tokens], dim=-1)
            if gen_length != 0 and (input_length + gen_length) < input_ids.shape[1]:
                break

        decoded, response, end_reason = decode_tokens(
            input_ids[0],
            tokenizer,
            raw_text_len=len(raw_text),
            context_length=len(context_tokens),
            errors="replace",
        )
        history.append((query, response))
        return input_ids[0].cpu().tolist(), history, decoded

    def prepareInput(self, tokenizer, query, query_assistant, history, system):
        # make_context builds the chat-formatted prompt text and its token ids.
        return make_context(tokenizer, query, query_assistant, history=history, system=system)

    def repetition_penalty(self, input_ids, next_token_scores):
        penalty = self.qwen.config.repetition_penalty
        score = torch.gather(next_token_scores, 1, input_ids)
        # Negative scores are multiplied by the penalty, positive ones divided, so
        # tokens already present in the context become less likely (for penalty > 1).
        score = torch.where(score < 0, score * penalty, score / penalty)
        next_token_scores = next_token_scores.scatter_(1, input_ids, score)
        return next_token_scores

    def top_p(self, next_token_scores):
        top_p = self.qwen.config.top_p
        filter_value = -float("Inf")
        min_tokens_to_keep = 1
        sorted_logits, sorted_indices = torch.sort(next_token_scores, descending=False)
        cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
        # With an ascending sort, tokens whose cumulative probability stays at or
        # below (1 - top_p) fall outside the nucleus and are removed.
        sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
        # Keep at least min_tokens_to_keep tokens (the highest-probability ones).
        sorted_indices_to_remove[..., -min_tokens_to_keep:] = 0
        # Scatter the mask back to the original (unsorted) indexing.
        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
        next_token_scores = next_token_scores.masked_fill(indices_to_remove, filter_value)
        return next_token_scores

    def sample(self, next_token_scores):
        probs = nn.functional.softmax(next_token_scores, dim=-1)
        next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
        return next_tokens

    def isFinish(self, next_tokens):
        pad_token_id = self.qwen.config.pad_token_id
        eos_token_id_tensor = torch.tensor([self.qwen.config.eos_token_id]).to(next_tokens.device)

        # Sequences that already finished keep emitting the pad token; a sequence is
        # marked finished once it emits the EOS token.
        next_tokens = next_tokens * self.unfinished_sequences + pad_token_id * (1 - self.unfinished_sequences)
        self.unfinished_sequences = self.unfinished_sequences.mul(
            next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
        )
        return self.unfinished_sequences.max() == 0, next_tokens[:, None]
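
# A minimal chat sketch (tokenizer construction is project-specific and assumed here;
# `make_context` / `decode_tokens` must be available as noted near the imports):
#
#   runner = ModelRunner(model)
#   token_ids, history, decoded = runner.Chat(tokenizer, "Hello!", "", gen_length=128)
#   print(decoded)
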
class QwenModule(pl.LightningModule):
    def __init__(self, conf: TrainConfig = None):
        super().__init__()
        self.config = conf
        pretrained_model_dir = conf.pretrain_model_name
        learning_rate = conf.learning_rate
        mconf = conf.model_config
        use_tril_attention_mask = conf.use_tril_attention_mask
        self.save_hyperparameters()
        if mconf is None:
            mconf = ModelConfig()
        model = QWenLMHeadModel(mconf)
        if pretrained_model_dir is not None:
            from modelscope import snapshot_download

            # LoadModule.from_pretrained loads the downloaded weights into the
            # freshly constructed model instance and returns it.
            model = LoadModule.from_pretrained(model, snapshot_download(pretrained_model_dir))
        self.llm = self.register_core_module(model)
        self.learning_rate = learning_rate
        self.use_tril_attention_mask = use_tril_attention_mask
        self.metric_loss = torchmetrics.MeanMetric()
        self.vocab_size = self.llm.config.vocab_size
        self.metric_accuracy = torchmetrics.Accuracy(
            task="multiclass",
            num_classes=self.vocab_size,
        )

    @cache
    def get_batch_tril_matrix(self, block_size: int, batch_size: Optional[int] = None) -> torch.Tensor:
        # Lower-triangular (causal) attention mask, optionally repeated per batch element.
        matrix = torch.ones(block_size, block_size).tril()
        if batch_size is not None:
            matrix = matrix.repeat(batch_size, 1, 1)
        return matrix

    def register_core_module(self, module: torch.nn.Module) -> torch.nn.Module:
        # Store a plain reference via object.__setattr__ so `__core_module__` is not
        # registered as a submodule; the model is registered once through `self.llm`.
        object.__setattr__(self, "__core_module__", module)
        return module

    def training_step(self, batch: Dict[str, torch.Tensor], batch_idx):
        batch_size, block_size = batch["input_ids"].shape
        if self.use_tril_attention_mask:
            batch["attention_mask"] = self.get_batch_tril_matrix(block_size, batch_size=batch_size).to(self.device)
        outputs, loss = self.llm(**batch)
        self.log("train_loss", loss, rank_zero_only=True)
        return loss

    def validation_step(self, batch: Dict[str, torch.Tensor], batch_idx):
        outputs, loss = self.llm(**batch, return_dict=True)
        # Shift logits and labels so each position is scored against the next token.
        logits = outputs[..., :-1, :]
        logits = logits.contiguous().view(-1, logits.size(-1))
        labels = batch["labels"][..., 1:]
        labels = labels.contiguous().view(-1)
        if "val_mask" in batch and batch["val_mask"] is not None:
            label_mask = batch["val_mask"][..., 1:]
            label_mask = label_mask.contiguous().view(-1)
            logits = logits[label_mask]
            labels = labels[label_mask]
        if logits.numel() != 0 and labels.numel() != 0:
            self.metric_accuracy.update(logits, labels)
        self.metric_loss.update(loss)

    def on_validation_epoch_end(self) -> None:
        self.log("val_loss", self.metric_loss, rank_zero_only=True)
        self.log("accuracy", self.metric_accuracy, rank_zero_only=True)
        self.log("hp_metric", self.metric_accuracy, rank_zero_only=True)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.trainer.model.parameters(), lr=self.learning_rate)
        return optimizer

    def configure_callbacks(self):
        checkpoint_callback = pl.callbacks.ModelCheckpoint(
            monitor="accuracy",
            mode="max",
            filename="{epoch:02d}-{accuracy:.4f}",
        )
        early_stop_callback = pl.callbacks.EarlyStopping(
            monitor="accuracy",
            min_delta=0.001,
            patience=3,
            mode="max",
            stopping_threshold=1,
        )
        lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval="step")
        # Only the LR monitor is active; the checkpoint and early-stopping callbacks
        # are kept as ready-made alternatives (see the commented returns below).
        return [lr_monitor]
        # return [checkpoint_callback, lr_monitor]
        # return [checkpoint_callback, early_stop_callback]
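
# A minimal training sketch (the dataloaders and the concrete TrainConfig fields are
# assumptions; only the fields accessed above are known to this file):
#
#   conf = TrainConfig()
#   module = QwenModule(conf)
#   trainer = pl.Trainer(max_epochs=1)
#   trainer.fit(module, train_dataloaders=train_loader, val_dataloaders=val_loader)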