from typing import Optional, Tuple, Union, Callable, List, Any, Generator from einops import rearrange import torch import torch.nn.functional as F import torch.utils.checkpoint from torch.nn import CrossEntropyLoss from torch import nn class QWenModel(nn.Module): class RMSNorm(torch.nn.Module): def __init__(self, dim: int, eps: float = 1e-6): super().__init__() self.eps = eps self.weight = nn.Parameter(torch.ones(dim)) def forward(self, x): norm = x.float() * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + self.eps) return norm.type_as(x) * self.weight class Block(nn.Module): class Attention(nn.Module): def __init__(self, config, index): super().__init__() self.hidden_size = config.hidden_size self.split_size = config.hidden_size self.num_heads = config.num_attention_heads self.head_dim = self.hidden_size // self.num_heads self.c_attn = nn.Linear(config.hidden_size, 3 * self.hidden_size) self.c_proj = nn.Linear(config.hidden_size, self.hidden_size, bias=not config.no_bias) self.attn_dropout = nn.Dropout(config.attn_dropout_prob) self.index = index def _split_heads(self, tensor, num_heads, attn_head_size): new_shape = tensor.size()[:-1] + (num_heads, attn_head_size) tensor = tensor.view(new_shape) return tensor def _merge_heads(self, tensor, num_heads, attn_head_size): tensor = tensor.contiguous() new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,) return tensor.view(new_shape) class MLP(nn.Module): def __init__(self, config): super().__init__() ff_dim_in = config.intermediate_size // 2 self.w1 = nn.Linear(config.hidden_size, ff_dim_in, bias=not config.no_bias) self.w2 = nn.Linear(config.hidden_size, ff_dim_in, bias=not config.no_bias) self.c_proj = nn.Linear(ff_dim_in, config.hidden_size, bias=not config.no_bias) def __init__(self, config, index): super().__init__() self.ln_1 = QWenModel.RMSNorm( config.hidden_size, eps=config.layer_norm_epsilon, ) self.attn = QWenModel.Block.Attention(config, index) self.ln_2 = QWenModel.RMSNorm( config.hidden_size, eps=config.layer_norm_epsilon, ) self.mlp = QWenModel.Block.MLP(config) self.index = index def __init__(self, config): super().__init__() self.wte = nn.Embedding(config.vocab_size, config.hidden_size) self.drop = nn.Dropout(config.emb_dropout_prob) self.dim = config.hidden_size // config.num_attention_heads self.h = nn.ModuleList([QWenModel.Block(config, i) for i in range(config.num_hidden_layers)]) self.ln_f = QWenModel.RMSNorm( config.hidden_size, eps=config.layer_norm_epsilon, ) self.base = config.rotary_emb_base inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) self._rotary_pos_emb_cache = None self._seq_len_cached = 0 self._ntk_alpha_cached = 1.0 def update_rotary_pos_emb_cache(self, seqlen, ntk_alpha=1.0): if seqlen > self._seq_len_cached or ntk_alpha != self._ntk_alpha_cached: base = self.base * ntk_alpha ** (self.dim / (self.dim - 2)) self.inv_freq = 1.0 / ( base ** (torch.arange(0, self.dim, 2, device=self.inv_freq.device).float() / self.dim) ) self._seq_len_cached = max(2 * seqlen, 16) self._ntk_alpha_cached = ntk_alpha seq = torch.arange(self._seq_len_cached, device=self.inv_freq.device) freqs = torch.outer(seq.type_as(self.inv_freq), self.inv_freq) emb = torch.cat((freqs, freqs), dim=-1) emb = rearrange(emb, "n d -> 1 n 1 d") cos, sin = emb.cos(), emb.sin() self._rotary_pos_emb_cache = [cos, sin] class QWenLMHeadModel(nn.Module): def __init__(self, config): super().__init__() self.config = config self.transformer = QWenModel(config) self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) self.hook_attention = None def apply_rotary_pos_emb(self, t, freqs): rot_dim = freqs[0].shape[-1] cos, sin = freqs t_float = t.float() t_rot = t_float[..., :rot_dim] t_pass = t_float[..., rot_dim:] x = rearrange(t_rot, "... (j d) -> ... j d", j=2) x1, x2 = x.unbind(dim=-2) _rotate_half = torch.cat((-x2, x1), dim=-1) t_rot = (t_rot * cos) + (_rotate_half * sin) return torch.cat((t_rot, t_pass), dim=-1).type_as(t) def forward( self, input_ids: Optional[torch.LongTensor] = None, labels: Optional[torch.LongTensor] = None, ): transfm = self.transformer input_shape = input_ids.size() input_ids = input_ids.view(-1, input_shape[-1]) hidden_states = transfm.wte(input_ids) kv_seq_len = hidden_states.size()[1] transfm.update_rotary_pos_emb_cache(kv_seq_len, ntk_alpha=1.0) cos, sin = transfm._rotary_pos_emb_cache rotary_pos_emb_list = [[cos[:, :kv_seq_len], sin[:, :kv_seq_len]]] hidden_states = transfm.drop(hidden_states) output_shape = input_shape + (hidden_states.size(-1),) for index, block in enumerate(transfm.h): layernorm_output = block.ln_1(hidden_states) # split_heads atten = block.attn mixed_x_layer = atten.c_attn(layernorm_output) query, key, value = mixed_x_layer.split(atten.split_size, dim=2) query = atten._split_heads(query, atten.num_heads, atten.head_dim) key = atten._split_heads(key, atten.num_heads, atten.head_dim) value = atten._split_heads(value, atten.num_heads, atten.head_dim) # pos_emb rotary_pos_emb = rotary_pos_emb_list[0] rotary_pos_emb = [i[:, -query.shape[1] :, :, :] for i in rotary_pos_emb] rotary_pos_emb = (rotary_pos_emb,) * 2 query = self.apply_rotary_pos_emb(query, rotary_pos_emb[0]) key = self.apply_rotary_pos_emb(key, rotary_pos_emb[1]) # build_mask size = query.size(1) causal_mask = torch.tril(torch.ones((size, size), dtype=torch.bool, device=query.device)).view( 1, 1, size, size ) # attention q = query.permute(0, 2, 1, 3) k = key.permute(0, 2, 1, 3) v = value.permute(0, 2, 1, 3) attn_output = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask).transpose(1, 2) if self.hook_attention: self.hook_attention(query, key, causal_mask, index) context_layer = block.attn._merge_heads(attn_output, block.attn.num_heads, block.attn.head_dim) attn_outputs = block.attn.c_proj(context_layer) layernorm_input = attn_outputs + hidden_states layernorm_output = block.ln_2(layernorm_input) a1 = block.mlp.w1(layernorm_output) a2 = block.mlp.w2(layernorm_output) intermediate_parallel = a1 * F.silu(a2) mlp_output = block.mlp.c_proj(intermediate_parallel) hidden_states = layernorm_input + mlp_output hidden_states = transfm.ln_f(hidden_states) hidden_states = hidden_states.view(output_shape) lm_logits = self.lm_head(hidden_states) loss = None if labels is not None: labels = labels.to(lm_logits.device) shift_labels = labels[..., 1:].contiguous().view(-1) shift_logits = lm_logits[..., :-1, :].contiguous() shift_logits = shift_logits.view(-1, shift_logits.size(-1)) mask = shift_labels < self.config.vocab_size shift_labels = shift_labels[mask] shift_logits = shift_logits[mask] loss = CrossEntropyLoss()(shift_logits, shift_labels) return lm_logits, loss