# Copyright (c) Alibaba Cloud.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import copy
import math
import inspect
from typing import TYPE_CHECKING, Optional, Tuple, Union, Callable, List, Any, Generator

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.nn import CrossEntropyLoss
from transformers import PreTrainedTokenizer, GenerationConfig, StoppingCriteriaList
from transformers.generation.logits_process import LogitsProcessorList

if TYPE_CHECKING:
    from transformers.generation.streamers import BaseStreamer

from transformers.generation.utils import GenerateOutput
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging

from torch import nn
from einops import rearrange

from configuration_qwen import QWenConfig
from qwen_generation_utils import (
    HistoryType,
    make_context,
    decode_tokens,
    StopWordsLogitsProcessor,
)

import sys

sys.path.append("..")
from tools import show
from tools import mem_tracker

# tracker = mem_tracker.MemTracker()
# tracker.track()


class QWenAttention(nn.Module):
    def __init__(self, config, index):
        super().__init__()

        self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False)
        self.seq_length = config.seq_length

        self.hidden_size = config.hidden_size
        self.split_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads

        self.scale_attn_weights = True

        self.projection_size = config.kv_channels * config.num_attention_heads
        assert self.projection_size % config.num_attention_heads == 0
        self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads

        self.c_attn = nn.Linear(config.hidden_size, 3 * self.projection_size)
        self.c_proj = nn.Linear(config.hidden_size, self.projection_size, bias=not config.no_bias)

        self.use_dynamic_ntk = config.use_dynamic_ntk
        logn_list = [math.log(i, self.seq_length) if i > self.seq_length else 1 for i in range(1, 32768)]
        logn_tensor = torch.tensor(logn_list)[None, :, None, None]
        self.register_buffer("logn_tensor", logn_tensor, persistent=False)

        self.attn_dropout = nn.Dropout(config.attn_dropout_prob)
        self.softmax_in_fp32 = config.softmax_in_fp32 if hasattr(config, "softmax_in_fp32") else False
        cache_dtype = torch.float
        self.cache_qmax = torch.tensor(torch.iinfo(torch.uint8).max, dtype=cache_dtype)
        self.cache_qmin = torch.tensor(torch.iinfo(torch.uint8).min, dtype=cache_dtype)

        self.index = index

    def _split_heads(self, tensor, num_heads, attn_head_size):
        new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
        tensor = tensor.view(new_shape)
        return tensor

    def _merge_heads(self, tensor, num_heads, attn_head_size):
        tensor = tensor.contiguous()
        new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
        return tensor.view(new_shape)
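
    # Added commentary (not original Qwen documentation) on the forward pass
    # below: hidden_states of shape (batch, seq_len, hidden_size) go through a
    # single fused QKV projection (c_attn) and are split into per-head
    # query/key/value of shape (batch, seq_len, num_heads, head_dim). Rotary
    # position embeddings are applied to query and key, an optional log-n
    # scaling stretches queries once the sequence exceeds the trained
    # seq_length, and the attention itself is delegated to
    # F.scaled_dot_product_attention with a mask that combines the causal mask
    # and the (additive) padding mask.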
    def forward(
        self,
        hidden_states: Optional[Tuple[torch.FloatTensor]],
        rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
    ):
        mixed_x_layer = self.c_attn(hidden_states)

        query, key, value = mixed_x_layer.split(self.split_size, dim=2)

        query = self._split_heads(query, self.num_heads, self.head_dim)
        key = self._split_heads(key, self.num_heads, self.head_dim)
        value = self._split_heads(value, self.num_heads, self.head_dim)

        # Slice the pos emb for current inference
        rotary_pos_emb = rotary_pos_emb_list[0]
        rotary_pos_emb = [i[:, -query.shape[1] :, :, :] for i in rotary_pos_emb]
        rotary_pos_emb = (rotary_pos_emb,) * 2
        q_pos_emb, k_pos_emb = rotary_pos_emb
        query = apply_rotary_pos_emb(query, q_pos_emb)
        key = apply_rotary_pos_emb(key, k_pos_emb)

        key_size = key.size(1)
        if key_size > self.seq_length and not self.training:
            seq_start = key.size(1) - query.size(1)
            seq_end = key.size(1)
            logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :].type_as(query)
            query = query * logn_tensor.expand_as(query)

        key_size = key.size(1)
        if query.size(1) == key_size:
            causal_mask = torch.tril(torch.ones((key_size, key_size), dtype=torch.bool, device=query.device)).view(
                1, 1, key_size, key_size
            )
        else:
            causal_mask = None

        query = query.permute(0, 2, 1, 3)
        key = key.permute(0, 2, 1, 3)
        value = value.permute(0, 2, 1, 3)

        if attention_mask is not None:
            attention_mask = attention_mask.expand(-1, -1, causal_mask.size(2), -1)
            if causal_mask is not None:
                attention_mask = attention_mask.masked_fill(~causal_mask, torch.finfo(query.dtype).min)
        else:
            attention_mask = causal_mask

        # qk = query @ key.transpose(-2, -1)
        # qk = qk[0]
        # show.DumpTensorToImage(qk,"q_matmul_k_layer_"+str(self.index)+".png")

        # print("layer:" + str(self.index) + " query.shape:" + str(query.shape))
        # print("layer:" + str(self.index) + " key.shape:" + str(key.shape))
        # print("layer:" + str(self.index) + " value.shape:" + str(value.shape))
        # print("\n")

        attn_output = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask).transpose(1, 2)

        context_layer = self._merge_heads(attn_output, self.num_heads, self.head_dim)

        attn_output = self.c_proj(context_layer)
        return attn_output


class QWenMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        ff_dim_in = config.intermediate_size // 2
        self.w1 = nn.Linear(config.hidden_size, ff_dim_in, bias=not config.no_bias)
        self.w2 = nn.Linear(config.hidden_size, ff_dim_in, bias=not config.no_bias)
        self.c_proj = nn.Linear(ff_dim_in, config.hidden_size, bias=not config.no_bias)

    def forward(self, hidden_states):
        # SwiGLU: gate the w1 branch with SiLU(w2) before projecting back.
        a1 = self.w1(hidden_states)
        a2 = self.w2(hidden_states)
        intermediate_parallel = a1 * F.silu(a2)
        output = self.c_proj(intermediate_parallel)
        return output


class QWenBlock(nn.Module):
    def __init__(self, config, index):
        super().__init__()
        hidden_size = config.hidden_size

        self.ln_1 = RMSNorm(
            hidden_size,
            eps=config.layer_norm_epsilon,
        )
        self.attn = QWenAttention(config, index)
        self.ln_2 = RMSNorm(
            hidden_size,
            eps=config.layer_norm_epsilon,
        )

        self.mlp = QWenMLP(config)
        self.index = index

    def forward(
        self,
        hidden_states: Optional[Tuple[torch.FloatTensor]],
        rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
    ):
        layernorm_output = self.ln_1(hidden_states)

        # QWenAttention.forward returns the projected attention output tensor
        # directly (not a tuple), so it is used as-is; indexing it with [0]
        # would slice off the batch dimension.
        attn_output = self.attn(
            layernorm_output,
            rotary_pos_emb_list,
            attention_mask=attention_mask,
        )

        residual = hidden_states
        layernorm_input = attn_output + residual

        layernorm_output = self.ln_2(layernorm_input)

        residual = layernorm_input
        mlp_output = self.mlp(layernorm_output)
        hidden_states = residual + mlp_output

        return hidden_states


class QWenPreTrainedModel(PreTrainedModel):
    config_class = QWenConfig
    base_model_prefix = "transformer"
    is_parallelizable = False
    supports_gradient_checkpointing = True
    _no_split_modules = ["QWenBlock"]

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)
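

# Added commentary: QWenModel is the bare decoder stack. Per forward call the
# data flow is roughly
#
#     input_ids -> wte -> drop -> [QWenBlock x num_hidden_layers] -> ln_f
#
# with a single rotary cos/sin table computed up front from the sequence
# length and a dynamic-NTK alpha (see get_ntk_alpha) once inputs run past
# config.seq_length. There is no KV cache in this simplified variant, so each
# forward pass re-encodes the full sequence.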
class QWenModel(QWenPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.vocab_size = config.vocab_size
        self.num_hidden_layers = config.num_hidden_layers
        self.embed_dim = config.hidden_size
        self.use_dynamic_ntk = config.use_dynamic_ntk
        self.seq_length = config.seq_length

        self.wte = nn.Embedding(self.vocab_size, self.embed_dim)
        self.drop = nn.Dropout(config.emb_dropout_prob)

        if config.rotary_pct == 1.0:
            self.rotary_ndims = None
        else:
            assert config.rotary_pct < 1
            self.rotary_ndims = int(config.kv_channels * config.rotary_pct)
        dim = self.rotary_ndims if self.rotary_ndims is not None else config.kv_channels
        self.rotary_emb = RotaryEmbedding(dim, base=config.rotary_emb_base)

        self.h = nn.ModuleList([QWenBlock(config, i) for i in range(config.num_hidden_layers)])
        self.ln_f = RMSNorm(
            self.embed_dim,
            eps=config.layer_norm_epsilon,
        )

        self.post_init()

    def get_ntk_alpha(self, true_seq_len):
        context_value = math.log(true_seq_len / self.seq_length, 2) + 1
        ntk_alpha = 2 ** math.ceil(context_value) - 1
        ntk_alpha = max(ntk_alpha, 1)
        return ntk_alpha

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
    ):
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
            batch_size = input_ids.shape[0]
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
            batch_size = inputs_embeds.shape[0]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if attention_mask is not None:
            attention_mask = attention_mask.view(batch_size, -1)
            attention_mask = attention_mask[:, None, None, :]
            attention_mask = attention_mask.to(dtype=self.dtype)
            attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min

        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        if inputs_embeds is None:
            inputs_embeds = self.wte(input_ids)
        hidden_states = inputs_embeds

        kv_seq_len = hidden_states.size()[1]

        if self.training or not self.use_dynamic_ntk:
            ntk_alpha_list = [1.0]
        elif kv_seq_len != hidden_states.size()[1]:
            ntk_alpha_list = self.rotary_emb._ntk_alpha_cached_list
        else:
            ntk_alpha_list = []
            if attention_mask is not None and kv_seq_len > self.seq_length:
                true_seq_lens = attention_mask.squeeze(1).squeeze(1).eq(0).sum(dim=-1, dtype=torch.int32)
                for i in range(hidden_states.size()[0]):
                    true_seq_len = true_seq_lens[i].item()
                    ntk_alpha = self.get_ntk_alpha(true_seq_len)
                    ntk_alpha_list.append(ntk_alpha)
            else:
                ntk_alpha = self.get_ntk_alpha(kv_seq_len)
                ntk_alpha_list.append(ntk_alpha)
        self.rotary_emb._ntk_alpha_cached_list = ntk_alpha_list

        rotary_pos_emb_list = [self.rotary_emb(kv_seq_len, ntk_alpha=ntk_alpha) for ntk_alpha in ntk_alpha_list]

        hidden_states = self.drop(hidden_states)
        output_shape = input_shape + (hidden_states.size(-1),)

        all_hidden_states = None
        for block in self.h:
            hidden_states = block(
                hidden_states,
                rotary_pos_emb_list=rotary_pos_emb_list,
                attention_mask=attention_mask,
            )

        hidden_states = self.ln_f(hidden_states)
        hidden_states = hidden_states.view(output_shape)

        return BaseModelOutputWithPast(last_hidden_state=hidden_states, hidden_states=all_hidden_states)
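

# Added commentary: QWenLMHeadModel wraps the decoder stack with a vocabulary
# projection (lm_head). When labels are given, the loss is the standard
# causal-LM objective, i.e. cross-entropy between the logits shifted left and
# the labels shifted right, flattened over batch and time:
#
#     loss = CE(logits[:, :-1, :].reshape(-1, vocab), labels[:, 1:].reshape(-1))
#
# The chat()/generate() methods below keep the usual GenerationMixin setup but
# inline the sampling loop so it is easy to read and instrument.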
class QWenLMHeadModel(QWenPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.transformer = QWenModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.post_init()

    def prepare_inputs_for_generation(self, input_ids, inputs_embeds=None, **kwargs):
        if input_ids.size(0) == 1:
            attention_mask = None
        else:
            attention_mask = kwargs.get("attention_mask", None)

        model_inputs = {"input_ids": input_ids}
        model_inputs.update(
            {
                "attention_mask": attention_mask,
            }
        )
        return model_inputs

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        transformer_outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
        )
        hidden_states = transformer_outputs[0]

        lm_logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            labels = labels.to(lm_logits.device)
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        # shift_labels = torch.ones([1,19]).to(lm_logits.device).to(torch.int64)
        # shift_logits = lm_logits[..., :-1, :].contiguous()
        # loss_fct = CrossEntropyLoss()
        # loss = loss_fct(
        #     shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
        # )
        # loss.backward()

        return CausalLMOutputWithPast(
            loss=loss,
            logits=lm_logits,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )

    @torch.no_grad()
    def chat(
        self,
        tokenizer: PreTrainedTokenizer,
        query: str,
        query_assistant: str,
        history: Optional[HistoryType],
        system: str = "You are a helpful assistant.",
        **kwargs,
    ) -> Tuple[str, HistoryType, str]:
        generation_config = self.generation_config

        if history is None:
            history = []
        else:
            # Deep-copy so the caller's history is not mutated in place.
            history = copy.deepcopy(history)

        stop_words_ids = []

        max_window_size = kwargs.get("max_window_size", None)
        if max_window_size is None:
            max_window_size = generation_config.max_window_size
        raw_text, context_tokens = make_context(
            tokenizer, query, query_assistant, history=history, system=system, max_window_size=max_window_size
        )

        stop_words_ids.extend([[tokenizer.im_end_id], [tokenizer.im_start_id]])
        input_ids = torch.tensor([context_tokens]).to(self.device)
        outputs = self.generate(
            input_ids,
            stop_words_ids=stop_words_ids,
            **kwargs,
        )
        decoded, response, end_reason = decode_tokens(
            outputs[0],
            tokenizer,
            raw_text_len=len(raw_text),
            context_length=len(context_tokens),
            errors="replace",
        )
        history.append((query, response))
        return response, history, decoded
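
    # Added commentary on the generate() override below: it keeps the setup
    # steps of transformers' GenerationMixin.generate (config merging, input
    # preparation, logits processors) but inlines a plain multinomial sampling
    # loop instead of dispatching to sample()/greedy_search(). A
    # StopWordsLogitsProcessor is always installed so that chat() can stop on
    # <|im_end|>/<|im_start|>. The numbered comments (# 1., # 2., ...) follow
    # the step numbering of the upstream HuggingFace implementation, which is
    # why some numbers are skipped.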
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        stop_words_ids=None,
        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
        **kwargs,
    ) -> Union[GenerateOutput, torch.LongTensor]:
        # Avoid a mutable default argument for the stop-word list.
        if stop_words_ids is None:
            stop_words_ids = []
        generation_config = self.generation_config

        # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
        self._validate_model_class()
        generation_config = copy.deepcopy(generation_config)
        model_kwargs = generation_config.update(**kwargs)  # All unused kwargs must be model kwargs
        generation_config.validate()
        self._validate_model_kwargs(model_kwargs.copy())

        # 2. Set generation parameters if not already defined
        if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
            eos_token_id = generation_config.eos_token_id
            if isinstance(eos_token_id, list):
                eos_token_id = eos_token_id[0]
            generation_config.pad_token_id = eos_token_id

        # 3. Define model inputs
        inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
            inputs, generation_config.bos_token_id, model_kwargs
        )

        # 4. Define other model kwargs
        accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
        requires_attention_mask = "encoder_outputs" not in model_kwargs

        if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask:
            model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
                inputs_tensor,
                generation_config.pad_token_id,
                generation_config.eos_token_id,
            )

        # 5. Prepare `input_ids` which will be used for auto-regressive generation
        input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")

        # 6. Prepare `max_length` depending on other stopping criteria.
        input_ids_length = input_ids.shape[-1]
        has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
        generation_config.max_length = generation_config.max_new_tokens + input_ids_length
        self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)

        stop_words_logits_processor = StopWordsLogitsProcessor(
            stop_words_ids=stop_words_ids,
            eos_token_id=generation_config.eos_token_id,
        )
        logits_processor = LogitsProcessorList([stop_words_logits_processor])
        logits_processor = self._get_logits_processor(
            generation_config=generation_config,
            input_ids_seq_length=input_ids_length,
            encoder_input_ids=inputs_tensor,
            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
            logits_processor=logits_processor,
            model_kwargs=model_kwargs,
            negative_prompt_ids=None,
            negative_prompt_attention_mask=None,
        )

        # 12. expand input_ids with `num_return_sequences` additional sequences per batch
        input_ids, model_kwargs = self._expand_inputs_for_generation(
            input_ids=input_ids,
            expand_size=generation_config.num_return_sequences,
            is_encoder_decoder=False,
            **model_kwargs,
        )
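
        # Added commentary: the loop below is a trimmed-down version of
        # GenerationMixin.sample(). Each iteration re-runs the full forward
        # pass (there is no KV cache), takes the logits of the last position,
        # applies the logits processors/warpers, and draws one token per
        # sequence:
        #
        #     probs = softmax(warp(process(logits[:, -1, :])))
        #     next_token = multinomial(probs, 1)
        #
        # Finished sequences keep emitting pad_token_id until every sequence
        # in the batch has produced an EOS token or a stopping criterion fires.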
        # 13. run sample
        pad_token_id = generation_config.pad_token_id
        eos_token_id_tensor = torch.tensor([generation_config.eos_token_id]).to(input_ids.device)

        # init values
        stopping_criteria = self._get_stopping_criteria(
            generation_config=generation_config, stopping_criteria=StoppingCriteriaList()
        )
        logits_warper = self._get_logits_warper(generation_config)

        # init attention / hidden states / scores tuples
        scores = None

        # keep track of which sequences are already finished
        unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)

        this_peer_finished = False
        # auto-regressive generation
        while True:
            # prepare model inputs
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

            # forward pass to get next token
            outputs = self(**model_inputs)
            next_token_scores = outputs.logits[:, -1, :]

            # pre-process distribution
            next_token_scores = logits_processor(input_ids, next_token_scores)
            next_token_scores = logits_warper(input_ids, next_token_scores)

            # sample
            probs = nn.functional.softmax(next_token_scores, dim=-1)
            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
            next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

            # update generated ids, model inputs, and length for next step
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
            model_kwargs = self._update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=False)
            unfinished_sequences = unfinished_sequences.mul(
                next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
            )

            # stop when each sentence is finished
            if unfinished_sequences.max() == 0:
                this_peer_finished = True

            # stop if we exceed the maximum length
            if stopping_criteria(input_ids, scores):
                this_peer_finished = True

            if this_peer_finished:
                break

        return input_ids


class RotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, base=10000):
        super().__init__()
        self.dim = dim
        self.base = base
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self._rotary_pos_emb_cache = None
        self._seq_len_cached = 0
        self._ntk_alpha_cached = 1.0
        self._ntk_alpha_cached_list = [1.0]

    def update_rotary_pos_emb_cache(self, seqlen, ntk_alpha=1.0):
        if seqlen > self._seq_len_cached or ntk_alpha != self._ntk_alpha_cached:
            base = self.base * ntk_alpha ** (self.dim / (self.dim - 2))
            self.inv_freq = 1.0 / (
                base ** (torch.arange(0, self.dim, 2, device=self.inv_freq.device).float() / self.dim)
            )
            self._seq_len_cached = max(2 * seqlen, 16)
            self._ntk_alpha_cached = ntk_alpha
            seq = torch.arange(self._seq_len_cached, device=self.inv_freq.device)
            freqs = torch.outer(seq.type_as(self.inv_freq), self.inv_freq)

            emb = torch.cat((freqs, freqs), dim=-1)
            emb = rearrange(emb, "n d -> 1 n 1 d")

            cos, sin = emb.cos(), emb.sin()
            self._rotary_pos_emb_cache = [cos, sin]

    def forward(self, max_seq_len, ntk_alpha=1.0):
        self.update_rotary_pos_emb_cache(max_seq_len, ntk_alpha)
        cos, sin = self._rotary_pos_emb_cache
        return [cos[:, :max_seq_len], sin[:, :max_seq_len]]
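

# Added commentary on the RoPE helpers below: RotaryEmbedding caches cos/sin
# tables of shape (1, seq_len, 1, dim), rescaling its base by
# ntk_alpha ** (dim / (dim - 2)) for NTK-aware long-context interpolation.
# apply_rotary_pos_emb then rotates query/key features by a position-dependent
# angle:
#
#     rope(t) = t * cos + rotate_half(t) * sin
#
# where rotate_half splits the rotary features into two halves (x1, x2) and
# maps them to (-x2, x1). Only the first rot_dim features are rotated; any
# remaining features pass through unchanged.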
j d", j=2) x1, x2 = x.unbind(dim=-2) return torch.cat((-x2, x1), dim=-1) def apply_rotary_pos_emb(t, freqs): rot_dim = freqs[0].shape[-1] cos, sin = freqs t_float = t.float() t_rot, t_pass = t_float[..., :rot_dim], t_float[..., rot_dim:] t_rot = (t_rot * cos) + (_rotate_half(t_rot) * sin) return torch.cat((t_rot, t_pass), dim=-1).type_as(t) class RMSNorm(torch.nn.Module): def __init__(self, dim: int, eps: float = 1e-6): super().__init__() self.eps = eps self.weight = nn.Parameter(torch.ones(dim)) def _norm(self, x): return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) def forward(self, x): output = self._norm(x.float()).type_as(x) return output * self.weight