Refine model definition.

This commit is contained in:
Colin 2025-02-28 13:16:39 +08:00
parent bff65b189d
commit e3b63f4635
7 changed files with 2351 additions and 329 deletions

wit/90800.ini (new file, 1036 lines added; diff suppressed because it is too large)

wit/Untitled-1.ini (new file, 1036 lines added; diff suppressed because it is too large)

View File

@@ -1,26 +1,23 @@
-import pytorch_lightning as pl
 import torch
 from model.qwen_module import QwenModule
-from model.modeling_wit import QwenRunner
+from model.qwen_module import ModelRunner
-from model.tokenization_qwen import QWenTokenizer
 import numpy as np
-import configuration
 import dataset.dataset as ds
-import dataset.node_tree as nt
 if __name__ == "__main__":
     # checkpoint_path = "log/bigger/version_0/checkpoints/epoch=19-step=98720.ckpt"
     checkpoint_path = "log/bigger/version_1/checkpoints/epoch=14-step=74040.ckpt"
+    checkpoint_path = "log/bigger/version_3/checkpoints/epoch=46-step=231992.ckpt"
     qwen = QwenModule.load_from_checkpoint(checkpoint_path=checkpoint_path)
     qwen.eval()
     conf = qwen.config
     torch.manual_seed(conf.seed)
     np.random.seed(conf.seed)
-    runner = QwenRunner(qwen.llm)
+    runner = ModelRunner(qwen.llm)
     # batch = torch.tensor([[11, 0, 3, 7, 15, 8, 10, 7, 14, 13, 1, 12, 13]], dtype=torch.int64)
     # sorted_logits, sorted_indices = runner.ChatTokens(batch, sample=False)
@@ -43,4 +40,4 @@ if __name__ == "__main__":
             if item[i] != next_token:
                 node.set_seq_prop(i, "ERR_" + str(next_token))
                 print(str(item[i]) + " " + str(next_token) + " ERROR")
-    node.print()
+    # node.print()

View File

@@ -1,10 +1,3 @@
-import copy
-import math
-import os
-import sys
-import gc
-from tqdm import auto as tqdm_lib
-import json
 from typing import Optional, Tuple, Union, Callable, List, Any, Generator
 from einops import rearrange
@@ -13,92 +6,73 @@ import torch.nn.functional as F
 import torch.utils.checkpoint
 from torch.nn import CrossEntropyLoss
 from torch import nn
-from safetensors.torch import load_file as safe_load_file
-from safetensors.torch import save_file as safe_save_file
-from model.qwen_generation_utils import (
-    make_context,
-    decode_tokens,
-)
-sys.path.append("..")
-from tools import show
-from tools import mem_tracker
-# tracker = mem_tracker.MemTracker()
-# tracker.track()
-class RMSNorm(torch.nn.Module):
-    def __init__(self, dim: int, eps: float = 1e-6):
-        super().__init__()
-        self.eps = eps
-        self.weight = nn.Parameter(torch.ones(dim))
-    def _norm(self, x):
-        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
-    def forward(self, x):
-        return self._norm(x.float()).type_as(x) * self.weight
-class QWenAttention(nn.Module):
-    def __init__(self, config, index):
-        super().__init__()
-        self.hidden_size = config.hidden_size
-        self.split_size = config.hidden_size
-        self.num_heads = config.num_attention_heads
-        self.head_dim = self.hidden_size // self.num_heads
-        self.c_attn = nn.Linear(config.hidden_size, 3 * self.hidden_size)
-        self.c_proj = nn.Linear(config.hidden_size, self.hidden_size, bias=not config.no_bias)
-        self.attn_dropout = nn.Dropout(config.attn_dropout_prob)
-        self.index = index
-    def _split_heads(self, tensor, num_heads, attn_head_size):
-        new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
-        tensor = tensor.view(new_shape)
-        return tensor
-    def _merge_heads(self, tensor, num_heads, attn_head_size):
-        tensor = tensor.contiguous()
-        new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
-        return tensor.view(new_shape)
-class QWenMLP(nn.Module):
-    def __init__(self, config):
-        super().__init__()
-        ff_dim_in = config.intermediate_size // 2
-        self.w1 = nn.Linear(config.hidden_size, ff_dim_in, bias=not config.no_bias)
-        self.w2 = nn.Linear(config.hidden_size, ff_dim_in, bias=not config.no_bias)
-        self.c_proj = nn.Linear(ff_dim_in, config.hidden_size, bias=not config.no_bias)
-class QWenBlock(nn.Module):
-    def __init__(self, config, index):
-        super().__init__()
-        self.ln_1 = RMSNorm(
-            config.hidden_size,
-            eps=config.layer_norm_epsilon,
-        )
-        self.attn = QWenAttention(config, index)
-        self.ln_2 = RMSNorm(
-            config.hidden_size,
-            eps=config.layer_norm_epsilon,
-        )
-        self.mlp = QWenMLP(config)
-        self.index = index
 class QWenModel(nn.Module):
+    class RMSNorm(torch.nn.Module):
+        def __init__(self, dim: int, eps: float = 1e-6):
+            super().__init__()
+            self.eps = eps
+            self.weight = nn.Parameter(torch.ones(dim))
+        def forward(self, x):
+            norm = x.float() * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + self.eps)
+            return norm.type_as(x) * self.weight
+    class Block(nn.Module):
+        class Attention(nn.Module):
+            def __init__(self, config, index):
+                super().__init__()
+                self.hidden_size = config.hidden_size
+                self.split_size = config.hidden_size
+                self.num_heads = config.num_attention_heads
+                self.head_dim = self.hidden_size // self.num_heads
+                self.c_attn = nn.Linear(config.hidden_size, 3 * self.hidden_size)
+                self.c_proj = nn.Linear(config.hidden_size, self.hidden_size, bias=not config.no_bias)
+                self.attn_dropout = nn.Dropout(config.attn_dropout_prob)
+                self.index = index
+            def _split_heads(self, tensor, num_heads, attn_head_size):
+                new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
+                tensor = tensor.view(new_shape)
+                return tensor
+            def _merge_heads(self, tensor, num_heads, attn_head_size):
+                tensor = tensor.contiguous()
+                new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
+                return tensor.view(new_shape)
+        class MLP(nn.Module):
+            def __init__(self, config):
+                super().__init__()
+                ff_dim_in = config.intermediate_size // 2
+                self.w1 = nn.Linear(config.hidden_size, ff_dim_in, bias=not config.no_bias)
+                self.w2 = nn.Linear(config.hidden_size, ff_dim_in, bias=not config.no_bias)
+                self.c_proj = nn.Linear(ff_dim_in, config.hidden_size, bias=not config.no_bias)
+        def __init__(self, config, index):
+            super().__init__()
+            self.ln_1 = QWenModel.RMSNorm(
+                config.hidden_size,
+                eps=config.layer_norm_epsilon,
+            )
+            self.attn = QWenModel.Block.Attention(config, index)
+            self.ln_2 = QWenModel.RMSNorm(
+                config.hidden_size,
+                eps=config.layer_norm_epsilon,
+            )
+            self.mlp = QWenModel.Block.MLP(config)
+            self.index = index
     def __init__(self, config):
         super().__init__()
         self.wte = nn.Embedding(config.vocab_size, config.hidden_size)
         self.drop = nn.Dropout(config.emb_dropout_prob)
         self.dim = config.hidden_size // config.num_attention_heads
-        self.h = nn.ModuleList([QWenBlock(config, i) for i in range(config.num_hidden_layers)])
-        self.ln_f = RMSNorm(
+        self.h = nn.ModuleList([QWenModel.Block(config, i) for i in range(config.num_hidden_layers)])
+        self.ln_f = QWenModel.RMSNorm(
             config.hidden_size,
             eps=config.layer_norm_epsilon,
         )
@@ -141,201 +115,7 @@ class QWenLMHeadModel(nn.Module):
         token_type_ids: Optional[torch.LongTensor] = None,
         **kwargs,
     ):
-        runner = QwenRunner(self)
+        transfm = self.transformer
-        return runner.forwardQWen(input_ids, labels)
-    def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]]):
-        pretrained_model_name_or_path = str(pretrained_model_name_or_path)
-        resolved_archive_file = os.path.join(pretrained_model_name_or_path, "model.safetensors.index.json")
-        print(f"loading weights file {resolved_archive_file}")
-        with open(resolved_archive_file, "r") as f:
-            index = json.loads(f.read())
-        shard_filenames = sorted(set(index["weight_map"].values()))
-        resolved_archive_file = [os.path.join(pretrained_model_name_or_path, f) for f in shard_filenames]
-        model = cls._load_pretrained_model(resolved_archive_file)
-        return model
-    def _load_state_dict_into_model(self, model_to_load, state_dict, start_prefix):
-        metadata = getattr(state_dict, "_metadata", None)
-        state_dict = state_dict.copy()
-        if metadata is not None:
-            state_dict._metadata = metadata
-        error_msgs = []
-        def load(module: nn.Module, state_dict, prefix=""):
-            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
-            args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
-            if len([key for key in state_dict if key.startswith(prefix)]) > 0:
-                module._load_from_state_dict(*args)
-            for name, child in module._modules.items():
-                if child is not None:
-                    load(child, state_dict, prefix + name + ".")
-        load(model_to_load, state_dict, prefix=start_prefix)
-        del state_dict
-        return error_msgs
-    def _load_pretrained_model(cls, resolved_archive_file):
-        start_prefix = ""
-        model_to_load = cls
-        if len(resolved_archive_file) > 1:
-            resolved_archive_file = tqdm_lib.tqdm(resolved_archive_file, desc="Loading checkpoint shards")
-        for shard_file in resolved_archive_file:
-            state_dict = safe_load_file(shard_file)
-            cls._load_state_dict_into_model(model_to_load, state_dict, start_prefix)
-            del state_dict # force memory release
-            gc.collect()
-        print(f"All model checkpoint weights were used when initializing {cls.__class__.__name__}.\n")
-        return cls
-class QwenRunner:
-    def __init__(self, qwen):
-        self.qwen = qwen
-        # torch.backends.cuda.enable_flash_sdp(True)
-    @torch.no_grad()
-    def ChatTokens(self, input_ids, sample=True):
-        qwen = self.qwen
-        input_ids = input_ids.to(next(qwen.parameters()).device)
-        outputs, loss = self.forwardQWen(input_ids)
-        next_token_scores = outputs[:, -1, :]
-        next_token_scores = self.repetition_penalty(input_ids, next_token_scores)
-        if sample:
-            next_token_scores = self.top_p(next_token_scores)
-            return self.sample(next_token_scores)
-        else:
-            return torch.sort(next_token_scores, descending=True)
-    @torch.no_grad()
-    def Chat(
-        self,
-        tokenizer,
-        query: str,
-        query_assistant: str,
-        gen_length=0,
-        system: str = "You are a helpful assistant.",
-        history=[],
-    ):
-        qwen = self.qwen
-        history = copy.deepcopy(history)
-        self.qwen.config.pad_token_id = tokenizer.eod_id
-        self.qwen.config.eos_token_id = tokenizer.eod_id
-        raw_text, context_tokens = self.prepareInput(tokenizer, query, query_assistant, history, system)
-        input_ids = torch.tensor([context_tokens]).to(next(qwen.parameters()).device)
-        self.unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
-        input_length = input_ids.shape[1]
-        while True:
-            outputs, loss = self.forwardQWen(input_ids)
-            next_token_scores = outputs[:, -1, :]
-            next_token_scores = self.repetition_penalty(input_ids, next_token_scores)
-            next_token_scores = self.top_p(next_token_scores)
-            next_tokens = self.sample(next_token_scores)
-            finish, next_tokens = self.isFinish(next_tokens)
-            if finish:
-                break
-            input_ids = torch.cat([input_ids, next_tokens], dim=-1)
-            if gen_length != 0 and (input_length + gen_length) < input_ids.shape[1]:
-                break
-        decoded, response, end_reason = decode_tokens(
-            input_ids[0],
-            tokenizer,
-            raw_text_len=len(raw_text),
-            context_length=len(context_tokens),
-            errors="replace",
-        )
-        history.append((query, response))
-        return input_ids[0].cpu().tolist(), history, decoded
-    def _rotate_half(self, x):
-        x = rearrange(x, "... (j d) -> ... j d", j=2)
-        x1, x2 = x.unbind(dim=-2)
-        return torch.cat((-x2, x1), dim=-1)
-    def apply_rotary_pos_emb(self, t, freqs):
-        rot_dim = freqs[0].shape[-1]
-        cos, sin = freqs
-        t_float = t.float()
-        t_rot = t_float[..., :rot_dim]
-        t_pass = t_float[..., rot_dim:]
-        t_rot = (t_rot * cos) + (self._rotate_half(t_rot) * sin)
-        return torch.cat((t_rot, t_pass), dim=-1).type_as(t)
-    def split_heads(
-        self,
-        attention,
-        hidden_states: Optional[Tuple[torch.FloatTensor]],
-    ):
-        atten = attention
-        mixed_x_layer = atten.c_attn(hidden_states)
-        query, key, value = mixed_x_layer.split(atten.split_size, dim=2)
-        query = atten._split_heads(query, atten.num_heads, atten.head_dim)
-        key = atten._split_heads(key, atten.num_heads, atten.head_dim)
-        value = atten._split_heads(value, atten.num_heads, atten.head_dim)
-        return query, key, value
-    def pos_emb(self, query, key, rotary_pos_emb_list):
-        rotary_pos_emb = rotary_pos_emb_list[0]
-        rotary_pos_emb = [i[:, -query.shape[1] :, :, :] for i in rotary_pos_emb]
-        rotary_pos_emb = (rotary_pos_emb,) * 2
-        query = self.apply_rotary_pos_emb(query, rotary_pos_emb[0])
-        key = self.apply_rotary_pos_emb(key, rotary_pos_emb[1])
-        return query, key
-    def attention(self, attention, query, key, value, causal_mask):
-        query = query.permute(0, 2, 1, 3)
-        key = key.permute(0, 2, 1, 3)
-        value = value.permute(0, 2, 1, 3)
-        attn_output = F.scaled_dot_product_attention(query, key, value, attn_mask=causal_mask).transpose(1, 2)
-        context_layer = attention._merge_heads(attn_output, attention.num_heads, attention.head_dim)
-        attn_output = attention.c_proj(context_layer)
-        return attn_output
-    def build_mask(self, query):
-        size = query.size(1)
-        causal_mask = torch.tril(torch.ones((size, size), dtype=torch.bool, device=query.device)).view(1, 1, size, size)
-        return causal_mask
-    def forwardAttention(
-        self,
-        attention,
-        hidden_states: Optional[Tuple[torch.FloatTensor]],
-        rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None,
-    ):
-        query, key, value = self.split_heads(attention, hidden_states)
-        query, key = self.pos_emb(query, key, rotary_pos_emb_list)
-        causal_mask = self.build_mask(query)
-        return self.attention(attention, query, key, value, causal_mask)
-    def forwardQWenBlock(
-        self,
-        block,
-        hidden_states: Optional[Tuple[torch.FloatTensor]],
-        rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None,
-    ):
-        layernorm_output = block.ln_1(hidden_states)
-        attn_outputs = self.forwardAttention(block.attn, layernorm_output, rotary_pos_emb_list)
-        layernorm_input = attn_outputs + hidden_states
-        layernorm_output = block.ln_2(layernorm_input)
-        a1 = block.mlp.w1(layernorm_output)
-        a2 = block.mlp.w2(layernorm_output)
-        intermediate_parallel = a1 * F.silu(a2)
-        mlp_output = block.mlp.c_proj(intermediate_parallel)
-        hidden_states = layernorm_input + mlp_output
-        return hidden_states
-    def forwardQWen(
-        self,
-        input_ids: Optional[torch.LongTensor] = None,
-        labels: Optional[torch.LongTensor] = None,
-    ):
-        transfm = self.qwen.transformer
         input_shape = input_ids.size()
         input_ids = input_ids.view(-1, input_shape[-1])
         hidden_states = transfm.wte(input_ids)
@@ -349,12 +129,12 @@ class QwenRunner:
         output_shape = input_shape + (hidden_states.size(-1),)
         for block in transfm.h:
-            hidden_states = self.forwardQWenBlock(block, hidden_states, rotary_pos_emb_list=rotary_pos_emb_list)
+            hidden_states = self.forwardBlock(block, hidden_states, rotary_pos_emb_list=rotary_pos_emb_list)
         hidden_states = transfm.ln_f(hidden_states)
         hidden_states = hidden_states.view(output_shape)
-        lm_logits = self.qwen.lm_head(hidden_states)
+        lm_logits = self.lm_head(hidden_states)
         loss = None
         if labels is not None:
@@ -362,7 +142,7 @@ class QwenRunner:
             shift_labels = labels[..., 1:].contiguous().view(-1)
             shift_logits = lm_logits[..., :-1, :].contiguous()
             shift_logits = shift_logits.view(-1, shift_logits.size(-1))
-            mask = shift_labels < self.qwen.config.vocab_size
+            mask = shift_labels < self.config.vocab_size
             shift_labels = shift_labels[mask]
            shift_logits = shift_logits[mask]
             # m = torch.max(shift_logits, 1).indices.cpu().numpy()
@@ -371,43 +151,60 @@ class QwenRunner:
         return lm_logits, loss
-    def prepareInput(self, tokenizer, query, query_assistant, history, system):
-        return make_context(tokenizer, query, query_assistant, history=history, system=system)
-    def repetition_penalty(self, input_ids, next_token_scores):
-        penalty = self.qwen.config.repetition_penalty
-        score = torch.gather(next_token_scores, 1, input_ids)
-        # if score < 0 then repetition penalty has to be multiplied to reduce the token probabilities
-        score = torch.where(score < 0, score * penalty, score / penalty)
-        next_token_scores = next_token_scores.scatter_(1, input_ids, score)
-        return next_token_scores
-    def top_p(self, next_token_scores):
-        top_p = self.qwen.config.top_p
-        filter_value = -float("Inf")
-        min_tokens_to_keep = 1
-        sorted_logits, sorted_indices = torch.sort(next_token_scores, descending=False)
-        cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
-        # Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
-        sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
-        # Keep at least min_tokens_to_keep
-        sorted_indices_to_remove[..., -min_tokens_to_keep:] = 0
-        # scatter sorted tensors to original indexing
-        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
-        next_token_scores = next_token_scores.masked_fill(indices_to_remove, filter_value)
-        return next_token_scores
-    def sample(self, next_token_scores):
-        probs = nn.functional.softmax(next_token_scores, dim=-1)
-        next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
-        return next_tokens
-    def isFinish(self, next_tokens):
-        pad_token_id = self.qwen.config.pad_token_id
-        eos_token_id_tensor = torch.tensor([self.qwen.config.eos_token_id]).to(next_tokens.device)
-        next_tokens = next_tokens * self.unfinished_sequences + pad_token_id * (1 - self.unfinished_sequences)
-        self.unfinished_sequences = self.unfinished_sequences.mul(
-            next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
-        )
-        return self.unfinished_sequences.max() == 0, next_tokens[:, None]
+    def apply_rotary_pos_emb(self, t, freqs):
+        rot_dim = freqs[0].shape[-1]
+        cos, sin = freqs
+        t_float = t.float()
+        t_rot = t_float[..., :rot_dim]
+        t_pass = t_float[..., rot_dim:]
+        x = rearrange(t_rot, "... (j d) -> ... j d", j=2)
+        x1, x2 = x.unbind(dim=-2)
+        _rotate_half = torch.cat((-x2, x1), dim=-1)
+        t_rot = (t_rot * cos) + (_rotate_half * sin)
+        return torch.cat((t_rot, t_pass), dim=-1).type_as(t)
+    def forwardBlock(
+        self,
+        block,
+        hidden_states: Optional[Tuple[torch.FloatTensor]],
+        rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None,
+    ):
+        layernorm_output = block.ln_1(hidden_states)
+        # split_heads
+        atten = block.attn
+        mixed_x_layer = atten.c_attn(layernorm_output)
+        query, key, value = mixed_x_layer.split(atten.split_size, dim=2)
+        query = atten._split_heads(query, atten.num_heads, atten.head_dim)
+        key = atten._split_heads(key, atten.num_heads, atten.head_dim)
+        value = atten._split_heads(value, atten.num_heads, atten.head_dim)
+        # pos_emb
+        rotary_pos_emb = rotary_pos_emb_list[0]
+        rotary_pos_emb = [i[:, -query.shape[1] :, :, :] for i in rotary_pos_emb]
+        rotary_pos_emb = (rotary_pos_emb,) * 2
+        query = self.apply_rotary_pos_emb(query, rotary_pos_emb[0])
+        key = self.apply_rotary_pos_emb(key, rotary_pos_emb[1])
+        # build_mask
+        size = query.size(1)
+        causal_mask = torch.tril(torch.ones((size, size), dtype=torch.bool, device=query.device)).view(1, 1, size, size)
+        # attention
+        q = query.permute(0, 2, 1, 3)
+        k = key.permute(0, 2, 1, 3)
+        v = value.permute(0, 2, 1, 3)
+        attn_output = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask).transpose(1, 2)
+        context_layer = block.attn._merge_heads(attn_output, block.attn.num_heads, block.attn.head_dim)
+        attn_outputs = block.attn.c_proj(context_layer)
+        layernorm_input = attn_outputs + hidden_states
+        layernorm_output = block.ln_2(layernorm_input)
+        a1 = block.mlp.w1(layernorm_output)
+        a2 = block.mlp.w2(layernorm_output)
+        intermediate_parallel = a1 * F.silu(a2)
+        mlp_output = block.mlp.c_proj(intermediate_parallel)
+        hidden_states = layernorm_input + mlp_output
+        return hidden_states

View File

@@ -1,5 +1,13 @@
+import os
+import gc
+import json
+from tqdm import auto as tqdm_lib
+from torch import nn
+from safetensors.torch import load_file as safe_load_file
+from safetensors.torch import save_file as safe_save_file
 from functools import cache
-from typing import Dict, Optional
+from typing import Dict, Optional, Union
 import pytorch_lightning as pl
 import torch
@@ -9,6 +17,154 @@ from model.modeling_wit import QWenLMHeadModel
 from configuration import ModelConfig, TrainConfig
+class LoadModule:
+    def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]]):
+        pretrained_model_name_or_path = str(pretrained_model_name_or_path)
+        resolved_archive_file = os.path.join(pretrained_model_name_or_path, "model.safetensors.index.json")
+        print(f"loading weights file {resolved_archive_file}")
+        with open(resolved_archive_file, "r") as f:
+            index = json.loads(f.read())
+        shard_filenames = sorted(set(index["weight_map"].values()))
+        resolved_archive_file = [os.path.join(pretrained_model_name_or_path, f) for f in shard_filenames]
+        model = LoadModule._load_pretrained_model(cls, resolved_archive_file)
+        return model
+    def _load_state_dict_into_model(model_to_load, state_dict, start_prefix):
+        metadata = getattr(state_dict, "_metadata", None)
+        state_dict = state_dict.copy()
+        if metadata is not None:
+            state_dict._metadata = metadata
+        error_msgs = []
+        def load(module: nn.Module, state_dict, prefix=""):
+            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
+            args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
+            if len([key for key in state_dict if key.startswith(prefix)]) > 0:
+                module._load_from_state_dict(*args)
+            for name, child in module._modules.items():
+                if child is not None:
+                    load(child, state_dict, prefix + name + ".")
+        load(model_to_load, state_dict, prefix=start_prefix)
+        del state_dict
+        return error_msgs
+    def _load_pretrained_model(cls, resolved_archive_file):
+        start_prefix = ""
+        model_to_load = cls
+        if len(resolved_archive_file) > 1:
+            resolved_archive_file = tqdm_lib.tqdm(resolved_archive_file, desc="Loading checkpoint shards")
+        for shard_file in resolved_archive_file:
+            state_dict = safe_load_file(shard_file)
+            LoadModule._load_state_dict_into_model(model_to_load, state_dict, start_prefix)
+            del state_dict # force memory release
+            gc.collect()
+        print(f"All model checkpoint weights were used when initializing {cls.__class__.__name__}.\n")
+        return cls
+class ModelRunner:
+    def __init__(self, qwen):
+        self.qwen = qwen
+    @torch.no_grad()
+    def ChatTokens(self, input_ids, sample=True):
+        qwen = self.qwen
+        input_ids = input_ids.to(next(qwen.parameters()).device)
+        outputs, loss = qwen.forward(input_ids)
+        next_token_scores = outputs[:, -1, :]
+        next_token_scores = self.repetition_penalty(input_ids, next_token_scores)
+        if sample:
+            next_token_scores = self.top_p(next_token_scores)
+            return self.sample(next_token_scores)
+        else:
+            return torch.sort(next_token_scores, descending=True)
+    @torch.no_grad()
+    def Chat(
+        self,
+        tokenizer,
+        query: str,
+        query_assistant: str,
+        gen_length=0,
+        system: str = "You are a helpful assistant.",
+        history=[],
+    ):
+        qwen = self.qwen
+        history = copy.deepcopy(history)
+        self.qwen.config.pad_token_id = tokenizer.eod_id
+        self.qwen.config.eos_token_id = tokenizer.eod_id
+        raw_text, context_tokens = qwen.prepareInput(tokenizer, query, query_assistant, history, system)
+        input_ids = torch.tensor([context_tokens]).to(next(qwen.parameters()).device)
+        self.unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
+        input_length = input_ids.shape[1]
+        while True:
+            outputs, loss = self.forward(input_ids)
+            next_token_scores = outputs[:, -1, :]
+            next_token_scores = self.repetition_penalty(input_ids, next_token_scores)
+            next_token_scores = self.top_p(next_token_scores)
+            next_tokens = self.sample(next_token_scores)
+            finish, next_tokens = self.isFinish(next_tokens)
+            if finish:
+                break
+            input_ids = torch.cat([input_ids, next_tokens], dim=-1)
+            if gen_length != 0 and (input_length + gen_length) < input_ids.shape[1]:
+                break
+        decoded, response, end_reason = decode_tokens(
+            input_ids[0],
+            tokenizer,
+            raw_text_len=len(raw_text),
+            context_length=len(context_tokens),
+            errors="replace",
+        )
+        history.append((query, response))
+        return input_ids[0].cpu().tolist(), history, decoded
+    def prepareInput(self, tokenizer, query, query_assistant, history, system):
+        return make_context(tokenizer, query, query_assistant, history=history, system=system)
+    def repetition_penalty(self, input_ids, next_token_scores):
+        penalty = self.qwen.config.repetition_penalty
+        score = torch.gather(next_token_scores, 1, input_ids)
+        # if score < 0 then repetition penalty has to be multiplied to reduce the token probabilities
+        score = torch.where(score < 0, score * penalty, score / penalty)
+        next_token_scores = next_token_scores.scatter_(1, input_ids, score)
+        return next_token_scores
+    def top_p(self, next_token_scores):
+        top_p = self.qwen.config.top_p
+        filter_value = -float("Inf")
+        min_tokens_to_keep = 1
+        sorted_logits, sorted_indices = torch.sort(next_token_scores, descending=False)
+        cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
+        # Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
+        sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
+        # Keep at least min_tokens_to_keep
+        sorted_indices_to_remove[..., -min_tokens_to_keep:] = 0
+        # scatter sorted tensors to original indexing
+        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
+        next_token_scores = next_token_scores.masked_fill(indices_to_remove, filter_value)
+        return next_token_scores
+    def sample(self, next_token_scores):
+        probs = nn.functional.softmax(next_token_scores, dim=-1)
+        next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
+        return next_tokens
+    def isFinish(self, next_tokens):
+        pad_token_id = self.qwen.config.pad_token_id
+        eos_token_id_tensor = torch.tensor([self.qwen.config.eos_token_id]).to(next_tokens.device)
+        next_tokens = next_tokens * self.unfinished_sequences + pad_token_id * (1 - self.unfinished_sequences)
+        self.unfinished_sequences = self.unfinished_sequences.mul(
+            next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
+        )
+        return self.unfinished_sequences.max() == 0, next_tokens[:, None]
 class QwenModule(pl.LightningModule):
     def __init__(self, conf: TrainConfig = None):
         self.config = conf
@@ -24,7 +180,7 @@ class QwenModule(pl.LightningModule):
         if pretrained_model_dir != None:
             from modelscope import snapshot_download
-            model = model.from_pretrained(snapshot_download(pretrained_model_dir))
+            model = LoadModule.from_pretrained(snapshot_download(pretrained_model_dir))
         self.llm = self.register_core_module(model)
         self.learning_rate = learning_rate
         self.use_tril_attention_mask = use_tril_attention_mask

View File

@@ -2,7 +2,7 @@ import pytorch_lightning as pl
 import torch
 from model.qwen_module import QwenModule
-from model.modeling_wit import QwenRunner
+from model.modeling_wit import ModelRunner
 from model.tokenization_qwen import QWenTokenizer
 import numpy as np

View File

@@ -18,7 +18,7 @@ if __name__ == "__main__":
     conf.pretrain_model_name = None # "qwen/Qwen-1_8B-Chat"
     conf.learning_rate = 0.0001
     conf.use_tril_attention_mask = None
-    conf.precision = "32-true" # "precision:bf16-mixed,16-mixed,32-true"
+    conf.precision = "16-mixed" # "precision:bf16-mixed,16-mixed,32-true"
     conf.train_batch_size = 16
     conf.val_batch_size = 2
     conf.num_proc = 8
@@ -38,7 +38,7 @@ if __name__ == "__main__":
     config.vocab_size = 32
     config.hidden_size = 128 # 128 1024 2048 32
     config.num_hidden_layers = 3 # 6 12 24 3
-    config.num_attention_heads = 16 # 8 8 16
+    config.num_attention_heads = 8 # 8 8 16
     torch.manual_seed(conf.seed)
     np.random.seed(conf.seed)