# Witllm/wit/model/modeling_wit.py
#
# (205 lines, 8.4 KiB, Python)

from typing import Optional, Tuple, Union, Callable, List, Any, Generator
from einops import rearrange
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.nn import CrossEntropyLoss
from torch import nn
class QWenModel(nn.Module):
    """Container for the Qwen decoder stack's submodules and its rotary-embedding cache.

    Holds the token embedding (``wte``), the embedding dropout (``drop``), the
    list of decoder blocks (``h``) and the final norm (``ln_f``). This class
    defines no ``forward`` of its own: callers drive these submodules directly
    and call :meth:`update_rotary_pos_emb_cache` before reading the cos/sin
    tables from ``_rotary_pos_emb_cache``.
    """

    class RMSNorm(torch.nn.Module):
        """Root-mean-square norm with a learned per-channel scale and no bias."""

        def __init__(self, dim: int, eps: float = 1e-6):
            super().__init__()
            self.eps = eps
            self.weight = nn.Parameter(torch.ones(dim))

        def forward(self, x):
            # Compute in float32 for numerical stability, then cast back to x's dtype.
            x_float = x.float()
            norm = x_float * torch.rsqrt(x_float.pow(2).mean(-1, keepdim=True) + self.eps)
            return norm.type_as(x) * self.weight

    class Block(nn.Module):
        """Parameters of one decoder layer: pre-attention norm, attention, pre-MLP norm, MLP."""

        class Attention(nn.Module):
            """Attention projections plus head split/merge helpers (defines no forward)."""

            def __init__(self, config, index):
                super().__init__()
                self.hidden_size = config.hidden_size
                self.split_size = config.hidden_size
                self.num_heads = config.num_attention_heads
                self.head_dim = self.hidden_size // self.num_heads
                # Fused Q/K/V projection; unlike c_proj it always carries a bias.
                self.c_attn = nn.Linear(config.hidden_size, 3 * self.hidden_size)
                self.c_proj = nn.Linear(config.hidden_size, self.hidden_size, bias=not config.no_bias)
                self.attn_dropout = nn.Dropout(config.attn_dropout_prob)
                self.index = index

            def _split_heads(self, tensor, num_heads, attn_head_size):
                """Reshape ``(..., num_heads * attn_head_size)`` to ``(..., num_heads, attn_head_size)``."""
                new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
                return tensor.view(new_shape)

            def _merge_heads(self, tensor, num_heads, attn_head_size):
                """Inverse of ``_split_heads``: flatten the last two dims into one."""
                tensor = tensor.contiguous()
                new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
                return tensor.view(new_shape)

        class MLP(nn.Module):
            """Feed-forward projections; the inner width is ``intermediate_size // 2``."""

            def __init__(self, config):
                super().__init__()
                ff_dim_in = config.intermediate_size // 2
                self.w1 = nn.Linear(config.hidden_size, ff_dim_in, bias=not config.no_bias)
                self.w2 = nn.Linear(config.hidden_size, ff_dim_in, bias=not config.no_bias)
                self.c_proj = nn.Linear(ff_dim_in, config.hidden_size, bias=not config.no_bias)

        def __init__(self, config, index):
            super().__init__()
            self.ln_1 = QWenModel.RMSNorm(
                config.hidden_size,
                eps=config.layer_norm_epsilon,
            )
            self.attn = QWenModel.Block.Attention(config, index)
            self.ln_2 = QWenModel.RMSNorm(
                config.hidden_size,
                eps=config.layer_norm_epsilon,
            )
            self.mlp = QWenModel.Block.MLP(config)
            self.index = index

    def __init__(self, config):
        super().__init__()
        self.wte = nn.Embedding(config.vocab_size, config.hidden_size)
        self.drop = nn.Dropout(config.emb_dropout_prob)
        # Rotary dimension == per-head dimension.
        self.dim = config.hidden_size // config.num_attention_heads
        self.h = nn.ModuleList([QWenModel.Block(config, i) for i in range(config.num_hidden_layers)])
        self.ln_f = QWenModel.RMSNorm(
            config.hidden_size,
            eps=config.layer_norm_epsilon,
        )
        self.base = config.rotary_emb_base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # [cos, sin] tables, each (1, cached_len, 1, dim); built lazily.
        self._rotary_pos_emb_cache = None
        self._seq_len_cached = 0
        self._ntk_alpha_cached = 1.0

    def update_rotary_pos_emb_cache(self, seqlen, ntk_alpha=1.0):
        """Ensure ``_rotary_pos_emb_cache`` holds ``[cos, sin]`` covering ``seqlen`` positions.

        Each table has shape ``(1, cached_len, 1, self.dim)`` with
        ``cached_len = max(2 * seqlen, 16)``. The tables are rebuilt when any of
        these holds:

        - there is no cache yet;
        - ``seqlen`` exceeds the cached length;
        - ``ntk_alpha`` changed;
        - the cache lives on a different device than ``inv_freq``. This happens
          when the module was moved with ``.to(...)``, because the cache is a
          plain list rather than a buffer.

        Rebuilding overwrites ``inv_freq``, scaling the base by
        ``ntk_alpha ** (dim / (dim - 2))``.

        Args:
            seqlen: number of positions that must be covered.
            ntk_alpha: base-scaling factor; ``1.0`` leaves ``rotary_emb_base`` unchanged.
        """
        cache = self._rotary_pos_emb_cache
        stale = (
            cache is None
            or seqlen > self._seq_len_cached
            or ntk_alpha != self._ntk_alpha_cached
            or cache[0].device != self.inv_freq.device
        )
        if not stale:
            return

        device = self.inv_freq.device
        # Skip the exponent when unscaled: avoids a ZeroDivisionError for dim == 2.
        if ntk_alpha == 1.0:
            base = self.base
        else:
            base = self.base * ntk_alpha ** (self.dim / (self.dim - 2))
        self.inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2, device=device).float() / self.dim))
        self._seq_len_cached = max(2 * seqlen, 16)
        self._ntk_alpha_cached = ntk_alpha

        seq = torch.arange(self._seq_len_cached, device=device)
        freqs = torch.outer(seq.type_as(self.inv_freq), self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        # (n, d) -> (1, n, 1, d) so the tables broadcast over batch and heads.
        emb = emb[None, :, None, :]
        self._rotary_pos_emb_cache = [emb.cos(), emb.sin()]
class QWenLMHeadModel(nn.Module):
    """Causal language model: a ``QWenModel`` decoder stack plus a linear LM head.

    ``forward`` runs the submodules of ``self.transformer`` directly, layer by
    layer. If ``hook_attention`` is set (truthy), it is called once per layer
    with ``(query, key, causal_mask, layer_index)``. ``query``/``key`` are the
    post-rotary tensors of shape ``(batch, seq, heads, head_dim)``. The same
    ``causal_mask`` tensor is shared by all layers and must not be modified in
    place.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.transformer = QWenModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Optional callable(query, key, causal_mask, layer_index) for inspection.
        self.hook_attention = None

    def apply_rotary_pos_emb(self, t, freqs):
        """Rotate the first ``rot_dim`` channels of ``t`` by the given cos/sin tables.

        Args:
            t: tensor whose last dim holds head channels. In ``forward`` it is
                ``(batch, seq, heads, head_dim)``.
            freqs: ``(cos, sin)`` pair broadcastable against ``t[..., :rot_dim]``.
                ``rot_dim`` is their last dimension.

        Returns:
            Tensor with the same shape and dtype as ``t``. The math runs in
            float32, and channels past ``rot_dim`` pass through unchanged.
        """
        cos, sin = freqs
        rot_dim = cos.shape[-1]
        t_float = t.float()
        t_rot, t_pass = t_float[..., :rot_dim], t_float[..., rot_dim:]
        # rotate_half: split into halves (x1, x2) and form (-x2, x1).
        x1, x2 = t_rot.chunk(2, dim=-1)
        rotated = torch.cat((-x2, x1), dim=-1)
        t_rot = (t_rot * cos) + (rotated * sin)
        return torch.cat((t_rot, t_pass), dim=-1).type_as(t)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        **kwargs,
    ):
        """Run the decoder and optionally compute the next-token loss.

        Args:
            input_ids: token ids (required). Leading dims are flattened into a
                batch dim for the computation and restored on the logits.
            labels: optional targets shaped like ``input_ids``. The loss compares
                ``logits[..., :-1, :]`` with ``labels[..., 1:]``.
            token_type_ids: accepted for interface compatibility; unused.
            **kwargs: accepted for interface compatibility; unused.

        Returns:
            ``(lm_logits, loss)``. ``lm_logits`` has shape
            ``input_ids.shape + (vocab_size,)``, and ``loss`` is ``None`` when
            ``labels`` is ``None``.

        Raises:
            ValueError: if ``input_ids`` is ``None``.
        """
        if input_ids is None:
            raise ValueError("input_ids must be provided")
        transfm = self.transformer
        input_shape = input_ids.size()
        # reshape (not view) so non-contiguous input_ids are accepted as well.
        input_ids = input_ids.reshape(-1, input_shape[-1])
        hidden_states = transfm.wte(input_ids)
        seq_len = hidden_states.size(1)

        transfm.update_rotary_pos_emb_cache(seq_len, ntk_alpha=1.0)
        cos, sin = transfm._rotary_pos_emb_cache
        # Both are identical for every layer, so build them once per call.
        rotary_pos_emb = (cos[:, :seq_len], sin[:, :seq_len])
        causal_mask = torch.tril(
            torch.ones((seq_len, seq_len), dtype=torch.bool, device=hidden_states.device)
        ).view(1, 1, seq_len, seq_len)

        hidden_states = transfm.drop(hidden_states)
        output_shape = input_shape + (hidden_states.size(-1),)
        for index, block in enumerate(transfm.h):
            attn_outputs = self._attention(
                block, block.ln_1(hidden_states), rotary_pos_emb, causal_mask, index
            )
            layernorm_input = attn_outputs + hidden_states
            mlp_output = self._mlp(block, block.ln_2(layernorm_input))
            hidden_states = layernorm_input + mlp_output

        hidden_states = transfm.ln_f(hidden_states)
        hidden_states = hidden_states.view(output_shape)
        lm_logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            loss = self._shift_loss(lm_logits, labels, self.config.vocab_size)
        return lm_logits, loss

    def _attention(self, block, x, rotary_pos_emb, causal_mask, index):
        """Causal self-attention of ``block`` on normalized input ``x``; returns the c_proj output."""
        attn = block.attn
        query, key, value = attn.c_attn(x).split(attn.split_size, dim=2)
        query = attn._split_heads(query, attn.num_heads, attn.head_dim)
        key = attn._split_heads(key, attn.num_heads, attn.head_dim)
        value = attn._split_heads(value, attn.num_heads, attn.head_dim)

        query = self.apply_rotary_pos_emb(query, rotary_pos_emb)
        key = self.apply_rotary_pos_emb(key, rotary_pos_emb)

        # SDPA takes (batch, heads, seq, head_dim). A boolean mask is True where
        # attention is allowed.
        attn_output = F.scaled_dot_product_attention(
            query.permute(0, 2, 1, 3),
            key.permute(0, 2, 1, 3),
            value.permute(0, 2, 1, 3),
            attn_mask=causal_mask,
        ).transpose(1, 2)
        if self.hook_attention:
            self.hook_attention(query, key, causal_mask, index)

        context_layer = attn._merge_heads(attn_output, attn.num_heads, attn.head_dim)
        return attn.c_proj(context_layer)

    @staticmethod
    def _mlp(block, x):
        """Feed-forward sub-layer: ``c_proj(w1(x) * silu(w2(x)))``."""
        mlp = block.mlp
        return mlp.c_proj(mlp.w1(x) * F.silu(mlp.w2(x)))

    @staticmethod
    def _shift_loss(lm_logits, labels, vocab_size):
        """Next-token cross-entropy: logits at position t are scored against labels at t+1.

        Target ids ``>= vocab_size`` are dropped before the loss. Other ids,
        including CrossEntropyLoss's default ``ignore_index`` of -100, are passed
        through unchanged.
        """
        labels = labels.to(lm_logits.device)
        shift_labels = labels[..., 1:].reshape(-1)
        shift_logits = lm_logits[..., :-1, :].reshape(-1, lm_logits.size(-1))
        keep = shift_labels < vocab_size
        return CrossEntropyLoss()(shift_logits[keep], shift_labels[keep])