# Witllm/wit/model/modeling_rwkv7.py

from typing import Optional, Tuple, Union, Callable, List, Any, Generator
from einops import rearrange
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.nn import CrossEntropyLoss
from torch import nn
import torch.nn.init as init
# tiny config for debugging (labeled "0.1B" in the original, but these sizes are much smaller)
n_layer = 3
n_embd = 256
D_DECAY_LORA = 64
D_AAA_LORA = 64
D_MV_LORA = 32
D_GATE_LORA = 128
dim_att = n_embd
dim_ffn = n_embd * 4
vocab_size = 32
# DTYPE = torch.bfloat16
DTYPE = torch.float32 # better
head_size_a = 64 # don't change
HS = head_size_a
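
# Derived sizes for this config: n_head = dim_att // head_size_a = 256 // 64 = 4,
# and dim_ffn = 4 * n_embd = 1024. Shape comments below are written generically
# as [B, T, C] with H heads of size HS.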

class RWKV_Tmix_x070(nn.Module):
    def __init__(self, layer_id):
        super().__init__()
        self.layer_id = layer_id

        self.head_size = head_size_a
        self.n_head = dim_att // self.head_size
        assert dim_att % self.n_head == 0
        H = self.n_head
        HS = self.head_size
        C = n_embd

        self.x_r = nn.Parameter(torch.empty(1, 1, C))
        self.x_w = nn.Parameter(torch.empty(1, 1, C))
        self.x_k = nn.Parameter(torch.empty(1, 1, C))
        self.x_v = nn.Parameter(torch.empty(1, 1, C))
        self.x_a = nn.Parameter(torch.empty(1, 1, C))
        self.x_g = nn.Parameter(torch.empty(1, 1, C))

        self.w0 = nn.Parameter(torch.empty(1, 1, C))
        self.w1 = nn.Parameter(torch.empty(C, D_DECAY_LORA))
        self.w2 = nn.Parameter(torch.empty(D_DECAY_LORA, C))

        self.a0 = nn.Parameter(torch.empty(1, 1, C))
        self.a1 = nn.Parameter(torch.empty(C, D_AAA_LORA))
        self.a2 = nn.Parameter(torch.empty(D_AAA_LORA, C))

        self.v0 = nn.Parameter(torch.empty(1, 1, C))
        self.v1 = nn.Parameter(torch.empty(C, D_MV_LORA))
        self.v2 = nn.Parameter(torch.empty(D_MV_LORA, C))

        self.g1 = nn.Parameter(torch.empty(C, D_GATE_LORA))
        self.g2 = nn.Parameter(torch.empty(D_GATE_LORA, C))

        self.k_k = nn.Parameter(torch.empty(1, 1, C))
        self.k_a = nn.Parameter(torch.empty(1, 1, C))
        self.r_k = nn.Parameter(torch.empty(H, HS))

        self.receptance = nn.Linear(C, C, bias=False)
        self.key = nn.Linear(C, C, bias=False)
        self.value = nn.Linear(C, C, bias=False)
        self.output = nn.Linear(C, C, bias=False)
        self.ln_x = nn.GroupNorm(H, C, eps=64e-5)  # !!! notice eps value !!!
    def forward(self, x, v_first):
        B, T, C = x.size()
        H = self.n_head

        # time_shift: xx[t] = x[t-1] - x[t], with x[-1] treated as zero
        xx = torch.zeros_like(x)                  # [B, T, C]
        xx[:, 0, :] = -x[:, 0, :]
        xx[:, 1:, :] = x[:, :-1, :] - x[:, 1:, :]

        xr = x + xx * self.x_r                    # [B, T, C] * [1, 1, C] -> [B, T, C]
        xw = x + xx * self.x_w
        xk = x + xx * self.x_k
        xv = x + xx * self.x_v
        xa = x + xx * self.x_a
        xg = x + xx * self.x_g

        r = self.receptance(xr)                   # Linear [B, T, C] -> [B, T, C]

        xw = torch.tanh(xw @ self.w1)             # -> [B, T, D_DECAY_LORA]
        xw = xw @ self.w2 + self.w0               # -> [B, T, C]
        xw = -F.softplus(-xw)                     # softplus maps to (0, inf), so -softplus(-xw) lies in (-inf, 0)
        w = xw - 0.5                              # soft-clamped decay, in (-inf, -0.5)

        k = self.key(xk)                          # Linear [B, T, C] -> [B, T, C]
        v = self.value(xv)                        # Linear [B, T, C] -> [B, T, C]
        if self.layer_id == 0:
            v_first = v                           # store the v of the first layer
        else:
            xv = (xv @ self.v1) @ self.v2
            xv = xv + self.v0
            xv = torch.sigmoid(xv)
            v = v + (v_first - v) * xv            # add value residual -> [B, T, C]

        xa = (xa @ self.a1) @ self.a2
        xa = xa + self.a0
        a = torch.sigmoid(xa)                     # -> [B, T, C]

        xg = xg @ self.g1
        xg = torch.sigmoid(xg)
        g = xg @ self.g2                          # gate -> [B, T, C]

        kk = k * self.k_k                         # [B, T, C] * [1, 1, C] -> [B, T, C]
        kk = F.normalize(kk.view(B, T, H, -1), dim=-1, p=2.0).view(B, T, C)
        k = k * (1 + (a - 1) * self.k_a)          # -> [B, T, C]

        # start op: per-head WKV state recurrence
        a_op = -kk
        b_op = kk * a
        B, T, C = r.size()
        H = C // HS                               # number of heads
        r_op = r.view(B, T, H, HS, 1).float()     # -> [B, T, H, HS, 1]
        k_op = k.view(B, T, H, 1, HS).float()     # -> [B, T, H, 1, HS]
        v_op = v.view(B, T, H, HS, 1).float()     # -> [B, T, H, HS, 1]
        a_op = a_op.view(B, T, H, HS, 1).float()  # -> [B, T, H, HS, 1]
        b_op = b_op.view(B, T, H, 1, HS).float()  # -> [B, T, H, 1, HS]
        w_op = w.view(B, T, H, HS).float()
        w_op = torch.exp(-torch.exp(w_op))        # decay factor in (0, 1) -> [B, T, H, HS]
        w_op = w_op.view(B, T, H, 1, HS)          # -> [B, T, H, 1, HS]

        out = torch.zeros((B, T, H, HS), device=r_op.device, dtype=torch.float)    # [B, T, H, HS]
        state = torch.zeros((B, H, HS, HS), device=r_op.device, dtype=torch.float)  # [B, H, HS, HS]
        for t in range(T):
            rr_op = r_op[:, t, :]                 # [B, H, HS, 1]
            kk_op = k_op[:, t, :]                 # [B, H, 1, HS]
            vv_op = v_op[:, t, :]                 # [B, H, HS, 1]
            aa_op = a_op[:, t, :]                 # [B, H, HS, 1]
            bb_op = b_op[:, t, :]                 # [B, H, 1, HS]
            ww_op = w_op[:, t, :]                 # [B, H, 1, HS]
            state = state * ww_op + state @ aa_op @ bb_op + vv_op @ kk_op  # [B, H, HS, HS]
            out[:, t, :] = (state @ rr_op).view(B, H, HS)
        x = out.view(B, T, C).to(dtype=DTYPE)     # -> [B, T, C]
        # end op

        x = self.ln_x(x.view(B * T, C)).view(B, T, C)

        # bonus term: per-head (r * k) weighted value, added back to the output
        xx = r.view(B, T, H, -1) * k.view(B, T, H, -1)   # [B, T, H, HS]
        xx = xx * self.r_k
        xx = xx.sum(dim=-1, keepdim=True)                # [B, T, H, 1]
        xx = xx * v.view(B, T, H, -1)                    # [B, T, H, HS]
        xx = xx.view(B, T, C)
        x = x + xx

        x = self.output(x * g)                    # Linear -> [B, T, C]
        return x, v_first
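
# The loop in RWKV_Tmix_x070.forward applies one state update per timestep. The
# helper below is a minimal reference sketch of that single step, kept here only
# as documentation of the recurrence; the name _wkv7_step_reference is
# illustrative (not part of the original module) and nothing calls it.
def _wkv7_step_reference(state, r, w, k, v, a, b):
    # state: [B, H, HS, HS]; r, v, a: [B, H, HS, 1]; w, k, b: [B, H, 1, HS]
    state = state * w + state @ a @ b + v @ k   # decay, in-context update, outer-product write
    out = (state @ r).squeeze(-1)               # read out with receptance -> [B, H, HS]
    return state, out
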

class RWKV_CMix_x070(nn.Module):
    def __init__(self, layer_id):
        super().__init__()
        self.layer_id = layer_id
        # with torch.no_grad():
        self.x_k = nn.Parameter(torch.empty(1, 1, n_embd))
        self.key = nn.Linear(n_embd, dim_ffn, bias=False)
        self.value = nn.Linear(dim_ffn, n_embd, bias=False)

    def forward(self, x):
        # time_shift: shift[t] = x[t-1], with x[-1] treated as zero
        shift = torch.zeros_like(x)
        shift[:, 1:, :] = x[:, :-1, :]
        xx = shift - x                            # [B, T, n_embd]
        k = x + xx * self.x_k                     # [B, T, n_embd]
        k = torch.relu(self.key(k)) ** 2          # Linear [B, T, n_embd] -> [B, T, dim_ffn]
        return self.value(k)                      # Linear [B, T, dim_ffn] -> [B, T, n_embd]

class Block(nn.Module):
    def __init__(self, layer_id):
        super().__init__()
        self.layer_id = layer_id
        self.ln0 = nn.LayerNorm(n_embd)  # only used in block 0; could be fused with emb
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)
        self.att = RWKV_Tmix_x070(layer_id)
        self.ffn = RWKV_CMix_x070(layer_id)

    def forward(self, x, v_first):
        if self.layer_id == 0:
            x = self.ln0(x)                       # LayerNorm over the channel dim, then scale by gamma and shift by beta
        ln = self.ln1(x)
        xx, v_first = self.att(ln, v_first)       # time mixing
        x = x + xx                                # residual
        x = x + self.ffn(self.ln2(x))             # channel mixing + residual
        return x, v_first

class RWKV(nn.Module):
    def __init__(self):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, n_embd)
        self.blocks = nn.ModuleList([Block(i) for i in range(n_layer)])
        self.ln_out = nn.LayerNorm(n_embd)
        self.head = nn.Linear(n_embd, vocab_size, bias=False)

    def forward(self, idx):
        x = self.emb(idx)                         # [B, T] -> [B, T, n_embd]
        v_first = torch.empty_like(x)             # filled in by the first block
        for block in self.blocks:
            x, v_first = block(x, v_first)
        x = self.ln_out(x)                        # [B, T, n_embd]
        x = self.head(x)                          # [B, T, n_embd] -> [B, T, vocab_size]
        return x

class RWKVLMHeadModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.rwkv = RWKV()
        # note: RWKV already applies its own output head; this lm_head is not used in forward()
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.hook_attention = None
        for name, param in self.rwkv.named_parameters():
            init.normal_(param.data)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        **kwargs,
    ):
        lm_logits = self.rwkv(input_ids)
        loss = None
        if labels is not None:
            labels = labels.to(lm_logits.device)
            # shift so that tokens up to position t-1 predict the token at position t
            shift_labels = labels[..., 1:].contiguous().view(-1)
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_logits = shift_logits.view(-1, shift_logits.size(-1))
            # drop positions whose label id is out of vocabulary (e.g. padding ids >= vocab_size)
            mask = shift_labels < self.config.vocab_size
            shift_labels = shift_labels[mask]
            shift_logits = shift_logits[mask]
            loss = CrossEntropyLoss()(shift_logits, shift_labels)
        return lm_logits, loss
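

# Minimal smoke test: a sketch only. The config class consumed by RWKVLMHeadModel
# is not defined in this file, so a SimpleNamespace with the two fields the model
# actually reads (hidden_size, vocab_size) is assumed here as a stand-in.
if __name__ == "__main__":
    from types import SimpleNamespace

    config = SimpleNamespace(hidden_size=n_embd, vocab_size=vocab_size)
    model = RWKVLMHeadModel(config)

    B, T = 1, 16
    input_ids = torch.randint(0, vocab_size, (B, T))
    labels = input_ids.clone()

    with torch.no_grad():
        logits, loss = model(input_ids=input_ids, labels=labels)
    print(logits.shape)   # torch.Size([1, 16, 32]) for this config
    print(loss.item())    # scalar cross-entropy loss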