| 
									
										
										
										
											2024-02-04 23:48:24 +08:00
										 |  |  | import copy | 
					
						
							|  |  |  | import math | 
					
						
							|  |  |  | import os | 
					
						
							|  |  |  | import sys | 
					
						
							|  |  |  | import gc | 
					
						
							|  |  |  | from tqdm import auto as tqdm_lib | 
					
						
							|  |  |  | import json | 
					
						
							|  |  |  | from typing import Optional, Tuple, Union, Callable, List, Any, Generator | 
					
						
							|  |  |  | from einops import rearrange | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | import torch | 
					
						
							|  |  |  | import torch.nn.functional as F | 
					
						
							|  |  |  | import torch.utils.checkpoint | 
					
						
							|  |  |  | from torch.nn import CrossEntropyLoss | 
					
						
							|  |  |  | from torch import nn | 
					
						
							|  |  |  | from safetensors.torch import load_file as safe_load_file | 
					
						
							|  |  |  | from safetensors.torch import save_file as safe_save_file | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | from qwen_generation_utils import ( | 
					
						
							|  |  |  |     make_context, | 
					
						
							|  |  |  |     decode_tokens, | 
					
						
							|  |  |  | ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | sys.path.append("..") | 
					
						
							|  |  |  | from tools import show | 
					
						
							|  |  |  | from tools import mem_tracker | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | # tracker = mem_tracker.MemTracker() | 
					
						
							|  |  |  | # tracker.track() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class RMSNorm(torch.nn.Module): | 
					
						
							|  |  |  |     def __init__(self, dim: int, eps: float = 1e-6): | 
					
						
							|  |  |  |         super().__init__() | 
					
						
							|  |  |  |         self.eps = eps | 
					
						
							|  |  |  |         self.weight = nn.Parameter(torch.ones(dim)) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def _norm(self, x): | 
					
						
							|  |  |  |         return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def forward(self, x): | 
					
						
							|  |  |  |         return self._norm(x.float()).type_as(x) * self.weight | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class QWenAttention(nn.Module): | 
					
						
							|  |  |  |     def __init__(self, config, index): | 
					
						
							|  |  |  |         super().__init__() | 
					
						
							|  |  |  |         self.hidden_size = config.hidden_size | 
					
						
							|  |  |  |         self.split_size = config.hidden_size | 
					
						
							|  |  |  |         self.num_heads = config.num_attention_heads | 
					
						
							|  |  |  |         self.head_dim = self.hidden_size // self.num_heads | 
					
						
							| 
									
										
										
										
											2024-03-04 21:41:46 +08:00
										 |  |  |         self.c_attn = nn.Linear(config.hidden_size, 3 * self.hidden_size) | 
					
						
							|  |  |  |         self.c_proj = nn.Linear(config.hidden_size, self.hidden_size, bias=not config.no_bias) | 
					
						
							| 
									
										
										
										
											2024-02-04 23:48:24 +08:00
										 |  |  |         self.attn_dropout = nn.Dropout(config.attn_dropout_prob) | 
					
						
							|  |  |  |         self.index = index | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def _split_heads(self, tensor, num_heads, attn_head_size): | 
					
						
							|  |  |  |         new_shape = tensor.size()[:-1] + (num_heads, attn_head_size) | 
					
						
							|  |  |  |         tensor = tensor.view(new_shape) | 
					
						
							|  |  |  |         return tensor | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def _merge_heads(self, tensor, num_heads, attn_head_size): | 
					
						
							|  |  |  |         tensor = tensor.contiguous() | 
					
						
							|  |  |  |         new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,) | 
					
						
							|  |  |  |         return tensor.view(new_shape) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class QWenMLP(nn.Module): | 
					
						
							|  |  |  |     def __init__(self, config): | 
					
						
							|  |  |  |         super().__init__() | 
					
						
							|  |  |  |         ff_dim_in = config.intermediate_size // 2 | 
					
						
							|  |  |  |         self.w1 = nn.Linear(config.hidden_size, ff_dim_in, bias=not config.no_bias) | 
					
						
							|  |  |  |         self.w2 = nn.Linear(config.hidden_size, ff_dim_in, bias=not config.no_bias) | 
					
						
							|  |  |  |         self.c_proj = nn.Linear(ff_dim_in, config.hidden_size, bias=not config.no_bias) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class QWenBlock(nn.Module): | 
					
						
							|  |  |  |     def __init__(self, config, index): | 
					
						
							|  |  |  |         super().__init__() | 
					
						
							|  |  |  |         self.ln_1 = RMSNorm( | 
					
						
							|  |  |  |             config.hidden_size, | 
					
						
							|  |  |  |             eps=config.layer_norm_epsilon, | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |         self.attn = QWenAttention(config, index) | 
					
						
							|  |  |  |         self.ln_2 = RMSNorm( | 
					
						
							|  |  |  |             config.hidden_size, | 
					
						
							|  |  |  |             eps=config.layer_norm_epsilon, | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |         self.mlp = QWenMLP(config) | 
					
						
							|  |  |  |         self.index = index | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class QWenModel(nn.Module): | 
					
						
							|  |  |  |     def __init__(self, config): | 
					
						
							|  |  |  |         super().__init__() | 
					
						
							|  |  |  |         self.wte = nn.Embedding(config.vocab_size, config.hidden_size) | 
					
						
							|  |  |  |         self.drop = nn.Dropout(config.emb_dropout_prob) | 
					
						
							| 
									
										
										
										
											2024-03-04 21:41:46 +08:00
										 |  |  |         self.dim = config.hidden_size // config.num_attention_heads | 
					
						
							| 
									
										
										
										
											2024-02-04 23:48:24 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         self.h = nn.ModuleList([QWenBlock(config, i) for i in range(config.num_hidden_layers)]) | 
					
						
							|  |  |  |         self.ln_f = RMSNorm( | 
					
						
							|  |  |  |             config.hidden_size, | 
					
						
							|  |  |  |             eps=config.layer_norm_epsilon, | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |         self.base = config.rotary_emb_base | 
					
						
							| 
									
										
										
										
											2024-03-04 21:41:46 +08:00
										 |  |  |         inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim)) | 
					
						
							| 
									
										
										
										
											2024-02-04 23:48:24 +08:00
										 |  |  |         self.register_buffer("inv_freq", inv_freq, persistent=False) | 
					
						
							|  |  |  |         self._rotary_pos_emb_cache = None | 
					
						
							|  |  |  |         self._seq_len_cached = 0 | 
					
						
							|  |  |  |         self._ntk_alpha_cached = 1.0 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def update_rotary_pos_emb_cache(self, seqlen, ntk_alpha=1.0): | 
					
						
							|  |  |  |         if seqlen > self._seq_len_cached or ntk_alpha != self._ntk_alpha_cached: | 
					
						
							|  |  |  |             base = self.base * ntk_alpha ** (self.dim / (self.dim - 2)) | 
					
						
							|  |  |  |             self.inv_freq = 1.0 / ( | 
					
						
							|  |  |  |                 base ** (torch.arange(0, self.dim, 2, device=self.inv_freq.device).float() / self.dim) | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  |             self._seq_len_cached = max(2 * seqlen, 16) | 
					
						
							|  |  |  |             self._ntk_alpha_cached = ntk_alpha | 
					
						
							|  |  |  |             seq = torch.arange(self._seq_len_cached, device=self.inv_freq.device) | 
					
						
							|  |  |  |             freqs = torch.outer(seq.type_as(self.inv_freq), self.inv_freq) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             emb = torch.cat((freqs, freqs), dim=-1) | 
					
						
							|  |  |  |             emb = rearrange(emb, "n d -> 1 n 1 d") | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             cos, sin = emb.cos(), emb.sin() | 
					
						
							|  |  |  |             self._rotary_pos_emb_cache = [cos, sin] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class QWenLMHeadModel(nn.Module): | 
					
						
							|  |  |  |     def __init__(self, config): | 
					
						
							|  |  |  |         super().__init__() | 
					
						
							|  |  |  |         self.config = config | 
					
						
							|  |  |  |         self.transformer = QWenModel(config) | 
					
						
							|  |  |  |         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-02-25 20:20:32 +08:00
										 |  |  |     def forward( | 
					
						
							|  |  |  |         self, | 
					
						
							|  |  |  |         input_ids: Optional[torch.LongTensor] = None, | 
					
						
							|  |  |  |         labels: Optional[torch.LongTensor] = None, | 
					
						
							|  |  |  |         token_type_ids: Optional[torch.LongTensor] = None, | 
					
						
							|  |  |  |         **kwargs, | 
					
						
							|  |  |  |     ): | 
					
						
							|  |  |  |         runner = QwenRunner(self) | 
					
						
							|  |  |  |         return runner.forwardQWen(input_ids, labels) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-02-04 23:48:24 +08:00
										 |  |  |     def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]]): | 
					
						
							|  |  |  |         pretrained_model_name_or_path = str(pretrained_model_name_or_path) | 
					
						
							|  |  |  |         resolved_archive_file = os.path.join(pretrained_model_name_or_path, "model.safetensors.index.json") | 
					
						
							|  |  |  |         print(f"loading weights file {resolved_archive_file}") | 
					
						
							|  |  |  |         with open(resolved_archive_file, "r") as f: | 
					
						
							|  |  |  |             index = json.loads(f.read()) | 
					
						
							|  |  |  |         shard_filenames = sorted(set(index["weight_map"].values())) | 
					
						
							|  |  |  |         resolved_archive_file = [os.path.join(pretrained_model_name_or_path, f) for f in shard_filenames] | 
					
						
							|  |  |  |         model = cls._load_pretrained_model(resolved_archive_file) | 
					
						
							|  |  |  |         return model | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def _load_state_dict_into_model(self, model_to_load, state_dict, start_prefix): | 
					
						
							|  |  |  |         metadata = getattr(state_dict, "_metadata", None) | 
					
						
							|  |  |  |         state_dict = state_dict.copy() | 
					
						
							|  |  |  |         if metadata is not None: | 
					
						
							|  |  |  |             state_dict._metadata = metadata | 
					
						
							|  |  |  |         error_msgs = [] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         def load(module: nn.Module, state_dict, prefix=""): | 
					
						
							|  |  |  |             local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) | 
					
						
							|  |  |  |             args = (state_dict, prefix, local_metadata, True, [], [], error_msgs) | 
					
						
							|  |  |  |             if len([key for key in state_dict if key.startswith(prefix)]) > 0: | 
					
						
							|  |  |  |                 module._load_from_state_dict(*args) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             for name, child in module._modules.items(): | 
					
						
							|  |  |  |                 if child is not None: | 
					
						
							|  |  |  |                     load(child, state_dict, prefix + name + ".") | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         load(model_to_load, state_dict, prefix=start_prefix) | 
					
						
							|  |  |  |         del state_dict | 
					
						
							|  |  |  |         return error_msgs | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def _load_pretrained_model(cls, resolved_archive_file): | 
					
						
							|  |  |  |         start_prefix = "" | 
					
						
							|  |  |  |         model_to_load = cls | 
					
						
							|  |  |  |         if len(resolved_archive_file) > 1: | 
					
						
							|  |  |  |             resolved_archive_file = tqdm_lib.tqdm(resolved_archive_file, desc="Loading checkpoint shards") | 
					
						
							|  |  |  |         for shard_file in resolved_archive_file: | 
					
						
							|  |  |  |             state_dict = safe_load_file(shard_file) | 
					
						
							|  |  |  |             cls._load_state_dict_into_model(model_to_load, state_dict, start_prefix) | 
					
						
							|  |  |  |             del state_dict  # force memory release | 
					
						
							|  |  |  |             gc.collect() | 
					
						
							|  |  |  |         print(f"All model checkpoint weights were used when initializing {cls.__class__.__name__}.\n") | 
					
						
							|  |  |  |         return cls | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class QwenRunner: | 
					
						
							|  |  |  |     def __init__(self, qwen): | 
					
						
							|  |  |  |         self.qwen = qwen | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @torch.no_grad() | 
					
						
							|  |  |  |     def Chat( | 
					
						
							|  |  |  |         self, | 
					
						
							|  |  |  |         tokenizer, | 
					
						
							|  |  |  |         query: str, | 
					
						
							|  |  |  |         query_assistant: str, | 
					
						
							| 
									
										
										
										
											2024-02-06 14:08:45 +08:00
										 |  |  |         gen_length=0, | 
					
						
							| 
									
										
										
										
											2024-02-04 23:48:24 +08:00
										 |  |  |         system: str = "You are a helpful assistant.", | 
					
						
							|  |  |  |         history=[], | 
					
						
							|  |  |  |     ): | 
					
						
							|  |  |  |         qwen = self.qwen | 
					
						
							|  |  |  |         history = copy.deepcopy(history) | 
					
						
							| 
									
										
										
										
											2024-02-06 14:08:45 +08:00
										 |  |  |         self.qwen.config.pad_token_id = tokenizer.eod_id | 
					
						
							|  |  |  |         self.qwen.config.eos_token_id = tokenizer.eod_id | 
					
						
							| 
									
										
										
										
											2024-02-04 23:48:24 +08:00
										 |  |  |         raw_text, context_tokens = self.prepareInput(tokenizer, query, query_assistant, history, system) | 
					
						
							|  |  |  |         input_ids = torch.tensor([context_tokens]).to(next(qwen.parameters()).device) | 
					
						
							|  |  |  |         self.unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device) | 
					
						
							| 
									
										
										
										
											2024-02-06 14:08:45 +08:00
										 |  |  |         input_length = input_ids.shape[1] | 
					
						
							| 
									
										
										
										
											2024-02-04 23:48:24 +08:00
										 |  |  |         while True: | 
					
						
							|  |  |  |             outputs = self.forwardQWen(input_ids) | 
					
						
							|  |  |  |             next_token_scores = outputs[:, -1, :] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             next_token_scores = self.repetition_penalty(input_ids, next_token_scores) | 
					
						
							|  |  |  |             next_token_scores = self.top_p(next_token_scores) | 
					
						
							|  |  |  |             next_tokens = self.sample(next_token_scores) | 
					
						
							|  |  |  |             finish, next_tokens = self.isFinish(next_tokens) | 
					
						
							|  |  |  |             if finish: | 
					
						
							|  |  |  |                 break | 
					
						
							|  |  |  |             input_ids = torch.cat([input_ids, next_tokens], dim=-1) | 
					
						
							| 
									
										
										
										
											2024-02-06 14:08:45 +08:00
										 |  |  |             if gen_length != 0 and (input_length + gen_length) < input_ids.shape[1]: | 
					
						
							|  |  |  |                 break | 
					
						
							| 
									
										
										
										
											2024-02-04 23:48:24 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         decoded, response, end_reason = decode_tokens( | 
					
						
							|  |  |  |             input_ids[0], | 
					
						
							|  |  |  |             tokenizer, | 
					
						
							|  |  |  |             raw_text_len=len(raw_text), | 
					
						
							|  |  |  |             context_length=len(context_tokens), | 
					
						
							|  |  |  |             errors="replace", | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |         history.append((query, response)) | 
					
						
							|  |  |  |         return input_ids[0].cpu().tolist(), history, decoded | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def _rotate_half(self, x): | 
					
						
							|  |  |  |         x = rearrange(x, "... (j d) -> ... j d", j=2) | 
					
						
							|  |  |  |         x1, x2 = x.unbind(dim=-2) | 
					
						
							|  |  |  |         return torch.cat((-x2, x1), dim=-1) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def apply_rotary_pos_emb(self, t, freqs): | 
					
						
							|  |  |  |         rot_dim = freqs[0].shape[-1] | 
					
						
							|  |  |  |         cos, sin = freqs | 
					
						
							|  |  |  |         t_float = t.float() | 
					
						
							| 
									
										
										
										
											2024-03-07 16:30:37 +08:00
										 |  |  |         t_rot = t_float[..., :rot_dim] | 
					
						
							|  |  |  |         t_pass = t_float[..., rot_dim:] | 
					
						
							| 
									
										
										
										
											2024-02-04 23:48:24 +08:00
										 |  |  |         t_rot = (t_rot * cos) + (self._rotate_half(t_rot) * sin) | 
					
						
							|  |  |  |         return torch.cat((t_rot, t_pass), dim=-1).type_as(t) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def split_heads( | 
					
						
							|  |  |  |         self, | 
					
						
							|  |  |  |         attention, | 
					
						
							|  |  |  |         hidden_states: Optional[Tuple[torch.FloatTensor]], | 
					
						
							|  |  |  |     ): | 
					
						
							|  |  |  |         atten = attention | 
					
						
							|  |  |  |         mixed_x_layer = atten.c_attn(hidden_states) | 
					
						
							|  |  |  |         query, key, value = mixed_x_layer.split(atten.split_size, dim=2) | 
					
						
							|  |  |  |         query = atten._split_heads(query, atten.num_heads, atten.head_dim) | 
					
						
							|  |  |  |         key = atten._split_heads(key, atten.num_heads, atten.head_dim) | 
					
						
							|  |  |  |         value = atten._split_heads(value, atten.num_heads, atten.head_dim) | 
					
						
							|  |  |  |         return query, key, value | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def pos_emb(self, query, key, rotary_pos_emb_list): | 
					
						
							|  |  |  |         rotary_pos_emb = rotary_pos_emb_list[0] | 
					
						
							|  |  |  |         rotary_pos_emb = [i[:, -query.shape[1] :, :, :] for i in rotary_pos_emb] | 
					
						
							|  |  |  |         rotary_pos_emb = (rotary_pos_emb,) * 2 | 
					
						
							|  |  |  |         query = self.apply_rotary_pos_emb(query, rotary_pos_emb[0]) | 
					
						
							|  |  |  |         key = self.apply_rotary_pos_emb(key, rotary_pos_emb[1]) | 
					
						
							|  |  |  |         return query, key | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def attention(self, attention, query, key, value, causal_mask): | 
					
						
							|  |  |  |         query = query.permute(0, 2, 1, 3) | 
					
						
							|  |  |  |         key = key.permute(0, 2, 1, 3) | 
					
						
							|  |  |  |         value = value.permute(0, 2, 1, 3) | 
					
						
							|  |  |  |         attn_output = F.scaled_dot_product_attention(query, key, value, attn_mask=causal_mask).transpose(1, 2) | 
					
						
							|  |  |  |         context_layer = attention._merge_heads(attn_output, attention.num_heads, attention.head_dim) | 
					
						
							|  |  |  |         attn_output = attention.c_proj(context_layer) | 
					
						
							|  |  |  |         return attn_output | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def build_mask(self, query): | 
					
						
							|  |  |  |         size = query.size(1) | 
					
						
							|  |  |  |         causal_mask = torch.tril(torch.ones((size, size), dtype=torch.bool, device=query.device)).view(1, 1, size, size) | 
					
						
							|  |  |  |         return causal_mask | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def forwardAttention( | 
					
						
							|  |  |  |         self, | 
					
						
							|  |  |  |         attention, | 
					
						
							|  |  |  |         hidden_states: Optional[Tuple[torch.FloatTensor]], | 
					
						
							|  |  |  |         rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None, | 
					
						
							|  |  |  |     ): | 
					
						
							|  |  |  |         query, key, value = self.split_heads(attention, hidden_states) | 
					
						
							|  |  |  |         query, key = self.pos_emb(query, key, rotary_pos_emb_list) | 
					
						
							|  |  |  |         causal_mask = self.build_mask(query) | 
					
						
							|  |  |  |         return self.attention(attention, query, key, value, causal_mask) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def forwardQWenBlock( | 
					
						
							|  |  |  |         self, | 
					
						
							|  |  |  |         block, | 
					
						
							|  |  |  |         hidden_states: Optional[Tuple[torch.FloatTensor]], | 
					
						
							|  |  |  |         rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None, | 
					
						
							|  |  |  |     ): | 
					
						
							|  |  |  |         layernorm_output = block.ln_1(hidden_states) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         attn_outputs = self.forwardAttention(block.attn, layernorm_output, rotary_pos_emb_list) | 
					
						
							|  |  |  |         attn_output = attn_outputs[0] | 
					
						
							|  |  |  |         layernorm_input = attn_output + hidden_states | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         layernorm_output = block.ln_2(layernorm_input) | 
					
						
							|  |  |  |         a1 = block.mlp.w1(layernorm_output) | 
					
						
							|  |  |  |         a2 = block.mlp.w2(layernorm_output) | 
					
						
							|  |  |  |         intermediate_parallel = a1 * F.silu(a2) | 
					
						
							|  |  |  |         mlp_output = block.mlp.c_proj(intermediate_parallel) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         hidden_states = layernorm_input + mlp_output | 
					
						
							|  |  |  |         return hidden_states | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def forwardQWen( | 
					
						
							|  |  |  |         self, | 
					
						
							|  |  |  |         input_ids: Optional[torch.LongTensor] = None, | 
					
						
							|  |  |  |         labels: Optional[torch.LongTensor] = None, | 
					
						
							|  |  |  |     ): | 
					
						
							|  |  |  |         transfm = self.qwen.transformer | 
					
						
							|  |  |  |         input_shape = input_ids.size() | 
					
						
							|  |  |  |         input_ids = input_ids.view(-1, input_shape[-1]) | 
					
						
							|  |  |  |         hidden_states = transfm.wte(input_ids) | 
					
						
							|  |  |  |         kv_seq_len = hidden_states.size()[1] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         transfm.update_rotary_pos_emb_cache(kv_seq_len, ntk_alpha=1.0) | 
					
						
							|  |  |  |         cos, sin = transfm._rotary_pos_emb_cache | 
					
						
							|  |  |  |         rotary_pos_emb_list = [[cos[:, :kv_seq_len], sin[:, :kv_seq_len]]] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         hidden_states = transfm.drop(hidden_states) | 
					
						
							|  |  |  |         output_shape = input_shape + (hidden_states.size(-1),) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         for block in transfm.h: | 
					
						
							|  |  |  |             hidden_states = self.forwardQWenBlock(block, hidden_states, rotary_pos_emb_list=rotary_pos_emb_list) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         hidden_states = transfm.ln_f(hidden_states) | 
					
						
							|  |  |  |         hidden_states = hidden_states.view(output_shape) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         lm_logits = self.qwen.lm_head(hidden_states) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         loss = None | 
					
						
							|  |  |  |         if labels is not None: | 
					
						
							|  |  |  |             labels = labels.to(lm_logits.device) | 
					
						
							| 
									
										
										
										
											2024-03-04 21:41:46 +08:00
										 |  |  |             shift_labels = labels[..., 1:].contiguous().view(-1) | 
					
						
							| 
									
										
										
										
											2024-02-04 23:48:24 +08:00
										 |  |  |             shift_logits = lm_logits[..., :-1, :].contiguous() | 
					
						
							| 
									
										
										
										
											2024-03-04 21:41:46 +08:00
										 |  |  |             shift_logits = shift_logits.view(-1, shift_logits.size(-1)) | 
					
						
							| 
									
										
										
										
											2024-03-05 22:08:37 +08:00
										 |  |  |             mask = shift_labels < self.qwen.config.vocab_size | 
					
						
							| 
									
										
										
										
											2024-03-04 21:41:46 +08:00
										 |  |  |             shift_labels = shift_labels[mask] | 
					
						
							|  |  |  |             shift_logits = shift_logits[mask] | 
					
						
							| 
									
										
										
										
											2024-03-07 16:30:37 +08:00
										 |  |  |             # m = torch.max(shift_logits, 1).indices.cpu().numpy() | 
					
						
							|  |  |  |             # ll = shift_labels.cpu().numpy() | 
					
						
							| 
									
										
										
										
											2024-03-05 22:08:37 +08:00
										 |  |  |             loss = CrossEntropyLoss()(shift_logits, shift_labels) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-02-25 20:20:32 +08:00
										 |  |  |         return lm_logits, loss | 
					
						
							| 
									
										
										
										
											2024-02-04 23:48:24 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     def prepareInput(self, tokenizer, query, query_assistant, history, system): | 
					
						
							|  |  |  |         return make_context(tokenizer, query, query_assistant, history=history, system=system) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def repetition_penalty(self, input_ids, next_token_scores): | 
					
						
							|  |  |  |         penalty = self.qwen.config.repetition_penalty | 
					
						
							|  |  |  |         score = torch.gather(next_token_scores, 1, input_ids) | 
					
						
							|  |  |  |         # if score < 0 then repetition penalty has to be multiplied to reduce the token probabilities | 
					
						
							|  |  |  |         score = torch.where(score < 0, score * penalty, score / penalty) | 
					
						
							|  |  |  |         next_token_scores = next_token_scores.scatter_(1, input_ids, score) | 
					
						
							|  |  |  |         return next_token_scores | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def top_p(self, next_token_scores): | 
					
						
							|  |  |  |         top_p = self.qwen.config.top_p | 
					
						
							|  |  |  |         filter_value = -float("Inf") | 
					
						
							|  |  |  |         min_tokens_to_keep = 1 | 
					
						
							|  |  |  |         sorted_logits, sorted_indices = torch.sort(next_token_scores, descending=False) | 
					
						
							|  |  |  |         cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1) | 
					
						
							|  |  |  |         # Remove tokens with cumulative top_p above the threshold (token with 0 are kept) | 
					
						
							|  |  |  |         sorted_indices_to_remove = cumulative_probs <= (1 - top_p) | 
					
						
							|  |  |  |         # Keep at least min_tokens_to_keep | 
					
						
							|  |  |  |         sorted_indices_to_remove[..., -min_tokens_to_keep:] = 0 | 
					
						
							|  |  |  |         # scatter sorted tensors to original indexing | 
					
						
							|  |  |  |         indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) | 
					
						
							|  |  |  |         next_token_scores = next_token_scores.masked_fill(indices_to_remove, filter_value) | 
					
						
							|  |  |  |         return next_token_scores | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def sample(self, next_token_scores): | 
					
						
							|  |  |  |         probs = nn.functional.softmax(next_token_scores, dim=-1) | 
					
						
							|  |  |  |         next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) | 
					
						
							|  |  |  |         return next_tokens | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def isFinish(self, next_tokens): | 
					
						
							|  |  |  |         pad_token_id = self.qwen.config.pad_token_id | 
					
						
							|  |  |  |         eos_token_id_tensor = torch.tensor([self.qwen.config.eos_token_id]).to(next_tokens.device) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         next_tokens = next_tokens * self.unfinished_sequences + pad_token_id * (1 - self.unfinished_sequences) | 
					
						
							|  |  |  |         self.unfinished_sequences = self.unfinished_sequences.mul( | 
					
						
							|  |  |  |             next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0) | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |         return self.unfinished_sequences.max() == 0, next_tokens[:, None] |