diff --git a/rwkv/RWKV-v7/RWKV.drawio b/rwkv/RWKV-v7/RWKV.drawio new file mode 100644 index 0000000..bcaed0f --- /dev/null +++ b/rwkv/RWKV-v7/RWKV.drawio @@ -0,0 +1,1026 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/rwkv/RWKV-v7/RWKV.png b/rwkv/RWKV-v7/RWKV.png new file mode 100644 index 0000000..d11d2c6 Binary files /dev/null and b/rwkv/RWKV-v7/RWKV.png differ diff --git a/rwkv/RWKV-v7/rwkv_v7_demo.py b/rwkv/RWKV-v7/rwkv_v7_demo.py index ea2524d..6dd37eb 100644 --- a/rwkv/RWKV-v7/rwkv_v7_demo.py +++ b/rwkv/RWKV-v7/rwkv_v7_demo.py @@ -40,7 +40,7 @@ args.vocab_size = 65536 DTYPE = torch.half # better args.head_size_a = 64 # don't change -HEAD_SIZE = args.head_size_a +HS = args.head_size_a ######################################################################################################## # RWKV Tokenizer (slow version) @@ -144,7 +144,7 @@ class RWKV_Tmix_x070(Module): assert args.dim_att % self.n_head == 0 H = self.n_head - N = self.head_size + HS = self.head_size C = args.n_embd self.x_r = nn.Parameter(torch.empty(1, 1, C)) @@ -171,7 +171,7 @@ class RWKV_Tmix_x070(Module): self.k_k = nn.Parameter(torch.empty(1, 1, C)) self.k_a = nn.Parameter(torch.empty(1, 1, C)) - self.r_k = nn.Parameter(torch.empty(H, N)) + self.r_k = nn.Parameter(torch.empty(H, HS)) self.receptance = nn.Linear(C, C, bias=False) self.key = nn.Linear(C, C, bias=False) @@ -183,10 +183,9 @@ class RWKV_Tmix_x070(Module): B, T, C = x.size() # seq_len H = self.n_head # 12 - shift = torch.zeros_like(x) - shift[:, 1:, :] = x[:, :-1, :] - - xx = shift - x # time_shift [1, seq_len, 768] -> [1, seq_len, 768] + xx = torch.zeros_like(x) # time_shift [1, seq_len, 768] -> [1, seq_len, 768] + xx[:, 0, :] = -x[:, 0, :] + xx[:, 1:, :] = x[:, :-1, :] - x[:, 1:, :] xr = x + xx * self.x_r # [1, seq_len, 768] * [1, 1, 768] -> [1, seq_len, 768] xw = x + xx * self.x_w # [1, seq_len, 768] * [1, 1, 768] -> [1, seq_len, 768] @@ -195,48 +194,71 @@ class RWKV_Tmix_x070(Module): xa = x + xx * self.x_a # [1, seq_len, 768] * [1, 1, 768] -> [1, seq_len, 768] xg = x + xx * self.x_g # [1, seq_len, 768] * [1, 1, 768] -> [1, seq_len, 768] - r = self.receptance(xr) # [1, seq_len, 768] -> [1, seq_len, 768] - # [1, seq_len, 768] -> [1, seq_len, 768] - w = -F.softplus(-(self.w0 + torch.tanh(xw @ self.w1) @ self.w2)) - 0.5 # softplus clamp to (-inf, -0.5) - k = self.key(xk) # [1, seq_len, 768] -> [1, seq_len, 768] - v = self.value(xv) # [1, seq_len, 768] -> [1, seq_len, 768] + r = self.receptance(xr) # Linear [1, seq_len, 768] -> [1, seq_len, 768] + + xw = torch.tanh(xw @ self.w1) # -> [1, seq_len, 64] + xw = xw @ self.w2 + self.w0 # -> [1, seq_len, 768] + xw = -F.softplus(-xw) # 函数的输出范围为 [0,∞) + w = xw - 0.5 # -> [1, seq_len, 768] + + k = self.key(xk) # Linear [1, seq_len, 768] -> [1, seq_len, 768] + + v = self.value(xv) # Linear [1, seq_len, 768] -> [1, seq_len, 768] + if self.layer_id == 0: v_first = v # store the v of the first layer else: - # -> [1, seq_len, 768] - v = v + (v_first - v) * torch.sigmoid(self.v0 + (xv @ self.v1) @ self.v2) # add value residual - a = torch.sigmoid(self.a0 + (xa @ self.a1) @ self.a2) # -> [1, seq_len, 768] # a is "in-context learning rate" - g = torch.sigmoid(xg @ self.g1) @ self.g2 # -> [1, seq_len, 768] + xv = (xv @ self.v1) @ self.v2 + xv = xv + self.v0 + xv = torch.sigmoid(xv) + v = v + (v_first - v) * xv # add value residual # -> [1, seq_len, 768] + + xa = (xa @ self.a1) @ self.a2 + xa = xa + self.a0 + a = torch.sigmoid(xa) # -> [1, seq_len, 768] + + xg = xg @ self.g1 + xg = torch.sigmoid(xg) + g = xg @ self.g2 # -> [1, seq_len, 768] kk = k * self.k_k # [1, seq_len, 768] * [1, 1, 768] -> [1, seq_len, 768] kk = F.normalize(kk.view(B, T, H, -1), dim=-1, p=2.0).view(B, T, C) # -> [1, seq_len, 768] + k = k * (1 + (a - 1) * self.k_a) # -> [1, seq_len, 768] - def RWKV7_OP(r, w, k, v, a, b): - B, T, C = r.size() # 768 - H = C // HEAD_SIZE # 12 - N = HEAD_SIZE # 64 - r = r.view(B, T, H, N).float() # -> [1, seq_len, 12, 64] - k = k.view(B, T, H, N).float() # -> [1, seq_len, 12, 64] - v = v.view(B, T, H, N).float() # -> [1, seq_len, 12, 64] - a = a.view(B, T, H, N).float() # -> [1, seq_len, 12, 64] - b = b.view(B, T, H, N).float() # -> [1, seq_len, 12, 64] - w = torch.exp(-torch.exp(w.view(B, T, H, N).float())) # -> [1, seq_len, 12, 64] - out = torch.zeros((B, T, H, N), device=r.device, dtype=torch.float) # -> [1, seq_len, 12, 64] - state = torch.zeros((B, H, N, N), device=r.device, dtype=torch.float) # -> [1, seq_len, 12, 64] + # start op + a_op = -kk + b_op = kk * a - for t in range(T): - kk = k[:, t, :].view(B, H, 1, N) # -> [1, 12, 1, 64] - rr = r[:, t, :].view(B, H, N, 1) # -> [1, 12, 64, 1] - vv = v[:, t, :].view(B, H, N, 1) # -> [1, 12, 64, 1] - aa = a[:, t, :].view(B, H, N, 1) # -> [1, 12, 64, 1] - bb = b[:, t, :].view(B, H, 1, N) # -> [1, 12, 1, 64] - state = state * w[:, t, :, None, :] + state @ aa @ bb + vv @ kk # -> [1, 12, 64, 64] - out[:, t, :] = (state @ rr).view(B, H, N) # -> [1, seq_len, 12, 64] + B, T, C = r.size() # 768 + H = C // HS # 12 - return out.view(B, T, C).to(dtype=DTYPE) # -> [1, seq_len, 768] + r_op = r.view(B, T, H, HS, 1).float() # -> [1, seq_len, 12, 64, 1] + k_op = k.view(B, T, H, 1, HS).float() # -> [1, seq_len, 12, 1, 64] + v_op = v.view(B, T, H, HS, 1).float() # -> [1, seq_len, 12, 64, 1] + a_op = a_op.view(B, T, H, HS, 1).float() # -> [1, seq_len, 12, 64, 1] + b_op = b_op.view(B, T, H, 1, HS).float() # -> [1, seq_len, 12, 1, 64] - x = RWKV7_OP(r, w, k, v, -kk, kk * a) # -> [1, seq_len, 768] + w_op = w.view(B, T, H, HS).float() + w_op = torch.exp(-torch.exp(w_op)) # -> [1, seq_len, 12, 64] + w_op = w_op.view(B, T, H, 1, HS) # -> [1, seq_len, 12, 1, 64] + + out = torch.zeros((B, T, H, HS), device=r_op.device, dtype=torch.float) # -> [1, seq_len, 12, 64] + state = torch.zeros((B, H, HS, HS), device=r_op.device, dtype=torch.float) # -> [1, seq_len, 12, 64] + + for t in range(T): + rr_op = r_op[:, t, :] # [1, seq_len, 12, 64, 1] -> [1, 12, 64, 1] + kk_op = k_op[:, t, :] # [1, seq_len, 12, 1, 64] -> [1, 12, 1, 64] + vv_op = v_op[:, t, :] # [1, seq_len, 12, 64, 1] -> [1, 12, 64, 1] + aa_op = a_op[:, t, :] # [1, seq_len, 12, 64, 1] -> [1, 12, 64, 1] + bb_op = b_op[:, t, :] # [1, seq_len, 12, 1, 64] -> [1, 12, 1, 64] + ww_op = w_op[:, t, :] # [1, seq_len, 12, 64] -> [1, 12, 1, 64] + + state = state * ww_op + state @ aa_op @ bb_op + vv_op @ kk_op # -> [1, 12, 64, 64] + out[:, t, :] = (state @ rr_op).view(B, H, HS) # -> [1, seq_len, 12, 64] + + x = out.view(B, T, C).to(dtype=DTYPE) # -> [1, seq_len, 768] + # end op x = self.ln_x(x.view(B * T, C)).view(B, T, C) # -> [1, seq_len, 768] @@ -246,7 +268,7 @@ class RWKV_Tmix_x070(Module): xx = xx * v.view(B, T, H, -1) # [1, seq_len, 12, 1] x [1, seq_len, 12, 64] -> [1, seq_len, 12, 64] xx = xx.view(B, T, C) # -> [1, seq_len, 768] x = x + xx # -> [1, seq_len, 768] - x = self.output(x * g) # -> [1, seq_len, 768] + x = self.output(x * g) # Linear -> [1, seq_len, 768] return x, v_first @@ -275,8 +297,8 @@ class RWKV_CMix_x070(Module): xx = shift - x # time_shift -> [1, seq_len, 768] k = x + xx * self.x_k # -> [1, seq_len, 768] - k = torch.relu(self.key(k)) ** 2 # -> [1, seq_len, 768] - return self.value(k) # -> [1, seq_len, 768] + k = torch.relu(self.key(k)) ** 2 # Linear -> [1, seq_len, 768] + return self.value(k) # Linear -> [1, seq_len, 768] ######################################################################################################## @@ -300,8 +322,8 @@ class Block(Module): def forward(self, x, v_first): if self.layer_id == 0: - x = self.ln0(x) # -> [1, seq_len, 768] normal at dim 768 * γ + β - ln = self.ln1(x) # -> [1, seq_len, 768] normal at dim 768 * γ + β + x = self.ln0(x) # LayerNorm -> [1, seq_len, 768] normal at dim 768 * γ + β + ln = self.ln1(x) # LayerNorm -> [1, seq_len, 768] normal at dim 768 * γ + β xx, v_first = self.att(ln, v_first) # [1, seq_len, 768] -> [1, seq_len, 768] x = x + xx # [1, seq_len, 768] -> [1, seq_len, 768] @@ -358,18 +380,17 @@ with torch.no_grad(): input = tokenizer.encode(prompt) print(f"\nInput:\n{input}") -# 中国的首都是在 -# 北 [probability 4.99%] -# 中 [probability 4.22%] -# 这 [probability 3.38%] -# 上 [probability 2.74%] -# 东 [probability 2.28%] -# 台 [probability 2.23%] -# 南 [probability 1.86%] -# 广 [probability 1.83%] -# 华 [probability 1.63%] -# 河 [probability 1.47%] - + # 中国的首都是在 + # 北 [probability 4.99%] + # 中 [probability 4.22%] + # 这 [probability 3.38%] + # 上 [probability 2.74%] + # 东 [probability 2.28%] + # 台 [probability 2.23%] + # 南 [probability 1.86%] + # 广 [probability 1.83%] + # 华 [probability 1.63%] + # 河 [probability 1.47%] out = model.forward(torch.tensor(input).reshape(1, -1).cuda()) print(f"\nOutput:\n{out}")