Refine rwkv.
parent 240858c030
commit 821e7206b8
@ -14,5 +14,4 @@ K and V are equivalent to the Key and Value in a Transformer.
TimeMix refers to mixing the previous step's information x_{t-1} into the current step's information x_t; xx = self.time_shift(x) - x is the canonical operation for this (see the sketch below).
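A minimal standalone sketch (not part of this commit) of what time_shift does: nn.ZeroPad2d((0, 0, 1, -1)) prepends one zero step along the sequence dimension and drops the last step, so time_shift(x)[t] holds x[t-1] and xx = time_shift(x) - x is "past minus present".

import torch
import torch.nn as nn

x = torch.arange(1, 7, dtype=torch.float32).view(1, 3, 2)  # [batch=1, seq_len=3, dim=2]
time_shift = nn.ZeroPad2d((0, 0, 1, -1))  # prepend one zero step, drop the last step
print(time_shift(x))       # row t now holds x[t-1]; row 0 is all zeros
print(time_shift(x) - x)   # the "past minus present" difference used by TimeMix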
nn.Embedding
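A minimal illustration of nn.Embedding as used here (the 65536 x 768 sizes are taken from the shape annotations later in this diff, so treat them as assumptions): it looks up token ids of shape [1, seq_len] and returns vectors of shape [1, seq_len, 768].

import torch
import torch.nn as nn

emb = nn.Embedding(65536, 768)       # vocab_size x n_embd lookup table
idx = torch.tensor([[11, 25, 300]])  # [1, seq_len] integer token ids
print(emb(idx).shape)                # torch.Size([1, 3, 768])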
RWKV_Tmix_x070
			@ -173,7 +173,6 @@ class RWKV_Tmix_x070(Module):
        self.k_a = nn.Parameter(torch.empty(1, 1, C))
        self.r_k = nn.Parameter(torch.empty(H, N))

        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
        self.receptance = nn.Linear(C, C, bias=False)
        self.key = nn.Linear(C, C, bias=False)
        self.value = nn.Linear(C, C, bias=False)
			@ -181,64 +180,73 @@ class RWKV_Tmix_x070(Module):
        self.ln_x = nn.GroupNorm(H, C, eps=64e-5)  # !!! notice eps value !!!

    def forward(self, x, v_first):
        B, T, C = x.size()
        H = self.n_head
        xx = self.time_shift(x) - x
        B, T, C = x.size()  # seq_len
        H = self.n_head  # 12

        xr = x + xx * self.x_r
        xw = x + xx * self.x_w
        xk = x + xx * self.x_k
        xv = x + xx * self.x_v
        xa = x + xx * self.x_a
        xg = x + xx * self.x_g
        shift = torch.zeros_like(x)
        shift[:, 1:, :] = x[:, :-1, :]

        r = self.receptance(xr)
        w = -F.softplus(-(self.w0 + torch.tanh(xw @ self.w1) @ self.w2)) - 0.5  # soft-clamp to (-inf, -0.5)
        k = self.key(xk)
        v = self.value(xv)
        xx = shift - x  # time_shift  [1, seq_len, 768] -> [1, seq_len, 768]

        xr = x + xx * self.x_r  # [1, seq_len, 768] * [1, 1, 768] -> [1, seq_len, 768]
        xw = x + xx * self.x_w  # [1, seq_len, 768] * [1, 1, 768] -> [1, seq_len, 768]
        xk = x + xx * self.x_k  # [1, seq_len, 768] * [1, 1, 768] -> [1, seq_len, 768]
        xv = x + xx * self.x_v  # [1, seq_len, 768] * [1, 1, 768] -> [1, seq_len, 768]
        xa = x + xx * self.x_a  # [1, seq_len, 768] * [1, 1, 768] -> [1, seq_len, 768]
        xg = x + xx * self.x_g  # [1, seq_len, 768] * [1, 1, 768] -> [1, seq_len, 768]

        r = self.receptance(xr)  # [1, seq_len, 768] -> [1, seq_len, 768]
        # [1, seq_len, 768] -> [1, seq_len, 768]
        w = -F.softplus(-(self.w0 + torch.tanh(xw @ self.w1) @ self.w2)) - 0.5  # softplus clamp to (-inf, -0.5)
        k = self.key(xk)  # [1, seq_len, 768] -> [1, seq_len, 768]
        v = self.value(xv)  # [1, seq_len, 768] -> [1, seq_len, 768]
        if self.layer_id == 0:
            v_first = v  # store the v of the first layer
        else:
            #  -> [1, seq_len, 768]
            v = v + (v_first - v) * torch.sigmoid(self.v0 + (xv @ self.v1) @ self.v2)  # add value residual
        a = torch.sigmoid(self.a0 + (xa @ self.a1) @ self.a2)  # a is "in-context learning rate"
        g = torch.sigmoid(xg @ self.g1) @ self.g2
        a = torch.sigmoid(self.a0 + (xa @ self.a1) @ self.a2)  # -> [1, seq_len, 768] # a is "in-context learning rate"
        g = torch.sigmoid(xg @ self.g1) @ self.g2  # -> [1, seq_len, 768]

        kk = k * self.k_k
        kk = F.normalize(kk.view(B, T, H, -1), dim=-1, p=2.0).view(B, T, C)
        k = k * (1 + (a - 1) * self.k_a)
        kk = k * self.k_k  # [1, seq_len, 768] * [1, 1, 768] -> [1, seq_len, 768]
        kk = F.normalize(kk.view(B, T, H, -1), dim=-1, p=2.0).view(B, T, C)  # -> [1, seq_len, 768]
        k = k * (1 + (a - 1) * self.k_a)  # -> [1, seq_len, 768]

        def RWKV7_OP(r, w, k, v, a, b):
            B, T, C = r.size()
            H = C // HEAD_SIZE
            N = HEAD_SIZE
            r = r.view(B, T, H, N).float()
            k = k.view(B, T, H, N).float()
            v = v.view(B, T, H, N).float()
            a = a.view(B, T, H, N).float()
            b = b.view(B, T, H, N).float()
            w = torch.exp(-torch.exp(w.view(B, T, H, N).float()))
            out = torch.zeros((B, T, H, N), device=r.device, dtype=torch.float)
            state = torch.zeros((B, H, N, N), device=r.device, dtype=torch.float)
            B, T, C = r.size()  # C = 768
            H = C // HEAD_SIZE  # 12
            N = HEAD_SIZE  # 64
            r = r.view(B, T, H, N).float()  # -> [1, seq_len, 12, 64]
            k = k.view(B, T, H, N).float()  # -> [1, seq_len, 12, 64]
            v = v.view(B, T, H, N).float()  # -> [1, seq_len, 12, 64]
            a = a.view(B, T, H, N).float()  # -> [1, seq_len, 12, 64]
            b = b.view(B, T, H, N).float()  # -> [1, seq_len, 12, 64]
            w = torch.exp(-torch.exp(w.view(B, T, H, N).float()))  # -> [1, seq_len, 12, 64]
            out = torch.zeros((B, T, H, N), device=r.device, dtype=torch.float)  # -> [1, seq_len, 12, 64]
            state = torch.zeros((B, H, N, N), device=r.device, dtype=torch.float)  # -> [1, 12, 64, 64]

            for t in range(T):
                kk = k[:, t, :].view(B, H, 1, N)
                rr = r[:, t, :].view(B, H, N, 1)
                vv = v[:, t, :].view(B, H, N, 1)
                aa = a[:, t, :].view(B, H, N, 1)
                bb = b[:, t, :].view(B, H, 1, N)
                state = state * w[:, t, :, None, :] + state @ aa @ bb + vv @ kk
                out[:, t, :] = (state @ rr).view(B, H, N)
                kk = k[:, t, :].view(B, H, 1, N)  # -> [1, 12, 1, 64]
                rr = r[:, t, :].view(B, H, N, 1)  # -> [1, 12, 64, 1]
                vv = v[:, t, :].view(B, H, N, 1)  # -> [1, 12, 64, 1]
                aa = a[:, t, :].view(B, H, N, 1)  # -> [1, 12, 64, 1]
                bb = b[:, t, :].view(B, H, 1, N)  # -> [1, 12, 1, 64]
                state = state * w[:, t, :, None, :] + state @ aa @ bb + vv @ kk  # -> [1, 12, 64, 64]
                out[:, t, :] = (state @ rr).view(B, H, N)  # -> [1, 12, 64] per step

            return out.view(B, T, C).to(dtype=DTYPE)
            return out.view(B, T, C).to(dtype=DTYPE)  # -> [1, seq_len, 768]

        x = RWKV7_OP(r, w, k, v, -kk, kk * a)
        x = RWKV7_OP(r, w, k, v, -kk, kk * a)  # -> [1, seq_len, 768]

        x = self.ln_x(x.view(B * T, C)).view(B, T, C)
        x = self.ln_x(x.view(B * T, C)).view(B, T, C)  # -> [1, seq_len, 768]

        x = x + (
            (r.view(B, T, H, -1) * k.view(B, T, H, -1) * self.r_k).sum(dim=-1, keepdim=True) * v.view(B, T, H, -1)
        ).view(B, T, C)
        x = self.output(x * g)
        xx = r.view(B, T, H, -1) * k.view(B, T, H, -1)  # -> [1, seq_len, 12, 64]
        xx = xx * self.r_k  # -> [1, seq_len, 12, 64]
        xx = xx.sum(dim=-1, keepdim=True)  # -> [1, seq_len, 12, 1]
        xx = xx * v.view(B, T, H, -1)  # [1, seq_len, 12, 1] x [1, seq_len, 12, 64] -> [1, seq_len, 12, 64]
        xx = xx.view(B, T, C)  # -> [1, seq_len, 768]
        x = x + xx  # -> [1, seq_len, 768]
        x = self.output(x * g)  # -> [1, seq_len, 768]
        return x, v_first
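A quick standalone check (not part of the commit) of the decay parameterization above: -F.softplus(-x) - 0.5 maps any real input into (-inf, -0.5), so the per-channel decay exp(-exp(w)) computed inside RWKV7_OP stays roughly within (0.545, 1).

import torch
import torch.nn.functional as F

raw = torch.linspace(-10.0, 10.0, steps=5)
w = -F.softplus(-raw) - 0.5        # soft-clamped to (-inf, -0.5)
decay = torch.exp(-torch.exp(w))   # the decay actually applied to the state
print(w)       # every entry is < -0.5
print(decay)   # every entry lies between exp(-exp(-0.5)) ~ 0.545 and 1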
			@ -252,7 +260,6 @@ class RWKV_CMix_x070(Module):
        super().__init__()
        self.args = args
        self.layer_id = layer_id
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))

        with torch.no_grad():
            self.x_k = nn.Parameter(torch.empty(1, 1, args.n_embd))
			@ -261,11 +268,15 @@ class RWKV_CMix_x070(Module):
        self.value = nn.Linear(args.dim_ffn, args.n_embd, bias=False)

    def forward(self, x):
        xx = self.time_shift(x) - x

        k = x + xx * self.x_k
        k = torch.relu(self.key(k)) ** 2
        return self.value(k)
        shift = torch.zeros_like(x)
        shift[:, 1:, :] = x[:, :-1, :]

        xx = shift - x  # time_shift -> [1, seq_len, 768]

        k = x + xx * self.x_k  # -> [1, seq_len, 768]
        k = torch.relu(self.key(k)) ** 2  # -> [1, seq_len, dim_ffn]
        return self.value(k)  # -> [1, seq_len, 768]
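A shape-level sketch of this channel-mix FFN (standalone; dim_ffn = 4 * n_embd is only an assumption here, not read from the commit): key expands n_embd -> dim_ffn, the activation is a squared ReLU, and value projects back to n_embd.

import torch
import torch.nn as nn

n_embd, dim_ffn = 768, 4 * 768  # dim_ffn = 4 * n_embd is an assumption
key = nn.Linear(n_embd, dim_ffn, bias=False)
value = nn.Linear(dim_ffn, n_embd, bias=False)

x = torch.randn(1, 5, n_embd)
h = torch.relu(key(x)) ** 2     # [1, 5, dim_ffn], squared-ReLU activation
print(value(h).shape)           # torch.Size([1, 5, 768])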
########################################################################################################
			@ -289,11 +300,12 @@ class Block(Module):
    def forward(self, x, v_first):

        if self.layer_id == 0:
            x = self.ln0(x)
            x = self.ln0(x)  # -> [1, seq_len, 768]   LayerNorm over the 768-dim channel, then * γ + β
        ln = self.ln1(x)  # -> [1, seq_len, 768]   LayerNorm over the 768-dim channel, then * γ + β

        xx, v_first = self.att(self.ln1(x), v_first)
        x = x + xx
        x = x + self.ffn(self.ln2(x))
        xx, v_first = self.att(ln, v_first)  # [1, seq_len, 768] -> [1, seq_len, 768]
        x = x + xx  # [1, seq_len, 768] -> [1, seq_len, 768]
        x = x + self.ffn(self.ln2(x))  # [1, seq_len, 768] -> [1, seq_len, 768]

        return x, v_first
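To make the "* γ + β" annotation concrete, a small standalone check (not part of the commit) that nn.LayerNorm over the last dimension matches the manual normalize-scale-shift formula:

import torch
import torch.nn as nn

x = torch.randn(1, 5, 768)
ln = nn.LayerNorm(768)
manual = (x - x.mean(-1, keepdim=True)) / torch.sqrt(x.var(-1, unbiased=False, keepdim=True) + ln.eps)
manual = manual * ln.weight + ln.bias            # * gamma + beta
print(torch.allclose(manual, ln(x), atol=1e-5))  # True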
			@ -317,14 +329,14 @@ class RWKV(nn.Module):
    def forward(self, idx):

        x = self.emb(idx)
        x = self.emb(idx)  # [1, seq_len] -> [1, seq_len, 768]

        v_first = torch.empty_like(x)
        v_first = torch.empty_like(x)  # -> [1, seq_len, 768]
        for block in self.blocks:
            x, v_first = block(x, v_first)

        x = self.ln_out(x)
        x = self.head(x)
        x = self.ln_out(x)  # [1, seq_len, 768] -> [1, seq_len, 768]
        x = self.head(x)  # [1, seq_len, 768] -> [1, seq_len, 65536]

        return x
			@ -346,6 +358,19 @@ with torch.no_grad():
    input = tokenizer.encode(prompt)
    print(f"\nInput:\n{input}")

# 中国的首都是在 ("The capital of China is in ...")
# 北 [probability 4.99%]
# 中 [probability 4.22%]
# 这 [probability 3.38%]
# 上 [probability 2.74%]
# 东 [probability 2.28%]
# 台 [probability 2.23%]
# 南 [probability 1.86%]
# 广 [probability 1.83%]
# 华 [probability 1.63%]
# 河 [probability 1.47%]

    out = model.forward(torch.tensor(input).reshape(1, -1).cuda())
    print(f"\nOutput:\n{out}")
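The probability list in the comments above can be reproduced along these lines (a sketch only; it reuses out and tokenizer from this script, and tokenizer.decode accepting a list of token ids is an assumption):

    probs = torch.softmax(out[0, -1].float(), dim=-1)  # distribution over the 65536-token vocab at the last position
    top_p, top_idx = torch.topk(probs, 10)
    for p, idx in zip(top_p.tolist(), top_idx.tolist()):
        print(f"{tokenizer.decode([idx])} [probability {p * 100:.2f}%]")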
			@ -0,0 +1,114 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
# RWKV-7 in numpy, by https://github.com/johanwind

import numpy as np
from torch import load as torch_load

layer_norm = lambda x, w, b: (x - x.mean()) / (x.var() + 1e-5) ** 0.5 * w + b
group_norm = (
    lambda x, w, b: ((x - x.mean(axis=1, keepdims=1)) / (x.var(axis=1, keepdims=1) + 64e-5) ** 0.5).flatten() * w + b
)
sigmoid = lambda x: 1 / (1 + np.exp(-x))


def time_mixing(x, v0, last_x, S, params):
    mr, mw, mk, mv, ma, mg, w_bias, r_k, Ww1, Ww2, Wa1, Wa2, a_bias, Wg1, Wg2 = params[:15]
    k_k, k_a, Wr, Wk, Wv, Wo, ln_w, ln_b = params[-8:]

    xr, xw, xk, xv, xa, xg = [x + m * (last_x - x) for m in [mr, mw, mk, mv, ma, mg]]

    r = Wr @ xr
    w = np.exp(-sigmoid(np.tanh(xw @ Ww1) @ Ww2 + w_bias) / np.e**0.5)
    k = Wk @ xk
    v = Wv @ xv
    if v0 is None:
        v0 = v
    else:
        Wv2, Wv1, v_bias = params[15:18]
        v += (v0 - v) * sigmoid(xv @ Wv1 @ Wv2 + v_bias)
    a = sigmoid(xa @ Wa1 @ Wa2 + a_bias)
    g = sigmoid(xg @ Wg1) @ Wg2
    kk = k * k_k
    k += k * (a - 1) * k_a

    r, w, k, v, kk, a, r_k = [i.reshape(N_HEAD, HEAD_SIZE, 1) for i in [r, w, k, v, kk, a, r_k]]
    kk /= np.maximum(np.linalg.norm(kk, axis=1, keepdims=1), 1e-12)

    S = S * w.mT - S @ kk * (kk * a).mT + v * k.mT
    y = S @ r

    y = group_norm(y, ln_w, ln_b)
    y += ((r * k * r_k).sum(axis=1, keepdims=1) * v).flatten()
    return Wo @ (y * g), v0, x, S
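Read together with the PyTorch RWKV7_OP earlier in this commit, the per-head state update implemented above can be summarized as follows (my transcription of the code, with \hat{k}_t the L2-normalized key direction):

S_t = S_{t-1}\,\mathrm{diag}(w_t) - S_{t-1}\,\hat{k}_t\,(\hat{k}_t \odot a_t)^{\top} + v_t\,k_t^{\top}, \qquad y_t = S_t\, r_t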

def channel_mixing(x, last_x, mix, Wk, Wv):
    k = Wk @ (x + mix * (last_x - x))
    v = Wv @ np.maximum(k, 0) ** 2
    return v, x


def RWKV7(params, token, state):
    x = params("emb")[0][token]
    x = layer_norm(x, *params("blocks.0.ln0"))

    v0 = None
    for i in range(N_LAYER):
        x_ = layer_norm(x, *params(f"blocks.{i}.ln1"))
        dx, v0, state[0][i, 0], state[1][i] = time_mixing(
            x_, v0, state[0][i, 0], state[1][i], params(f"blocks.{i}.att")
        )
        x = x + dx

        x_ = layer_norm(x, *params(f"blocks.{i}.ln2"))
        dx, state[0][i, 1] = channel_mixing(x_, state[0][i, 1], *params(f"blocks.{i}.ffn"))
        x = x + dx

    x = layer_norm(x, *params("ln_out"))
    logits = params("head")[0] @ x

    return logits, state


# Verification

# Available at https://huggingface.co/BlinkDL/rwkv-7-world/resolve/main/RWKV-x070-World-0.4B-v2.9-20250107-ctx4096.pth
MODEL_FILE = "/home/colin/.cache/modelscope/hub/Blink_DL/rwkv-7-world/RWKV-x070-World-0.4B-v2.9-20250107-ctx4096.pth"

N_LAYER = 24
N_EMBD = 1024
HEAD_SIZE = 64
N_HEAD = N_EMBD // HEAD_SIZE

if 1:  # Reference implementation
    context = "\nIn a shocking finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the fact that the dragons spoke perfect Chinese."

    # pip install rwkv
    import os

    os.environ["RWKV_V7_ON"] = "1"
    from rwkv.utils import PIPELINE
    from rwkv.model import RWKV as referenceRWKV

    model = referenceRWKV(model=MODEL_FILE[:-4], strategy="cpu fp32")
    pipeline = PIPELINE(model, "rwkv_vocab_v20230424")
    tokens = pipeline.encode(context)

    reference_logits, state = model.forward(tokens, None)
    reference_logits = reference_logits.numpy()

weights = torch_load(MODEL_FILE, map_location="cpu", weights_only=True)
weights = {k: v.squeeze().float().numpy() for k, v in weights.items()}
params = lambda prefix: [weights[key] for key in weights.keys() if key.startswith(prefix)]

state = (
    np.zeros((N_LAYER, 2, N_EMBD), dtype=np.float32),
    np.zeros((N_LAYER, N_HEAD, HEAD_SIZE, HEAD_SIZE), dtype=np.float32),
)
for token in tokens:
    minimal_logits, state = RWKV7(params, token, state)

print("Deviation from official rwkv:", max(abs(minimal_logits - reference_logits)) / reference_logits.std())
			@ -1,5 +1,35 @@
import torch
import torch.nn.functional as F
import torch.nn as nn


import torch
import torch.nn as nn

# Assume the input is a batch of sequences with shape (batch_size, seq_len, hidden_dim)
batch_size, seq_len, hidden_dim = 2, 2, 4
input_tensor = torch.randn(batch_size, seq_len, hidden_dim)

# Define the LayerNorm layers
layer_norm1 = nn.LayerNorm(hidden_dim)
layer_norm2 = nn.LayerNorm(hidden_dim)

# Apply LayerNorm
output = layer_norm1(input_tensor)
print(input_tensor.numpy())
print("\n")
print("\n")
print(output.detach().numpy())

output = layer_norm2(output)
print("\n")
print("\n")
print(output.detach().numpy())

x1 = torch.empty((1, 7, 768), dtype=float)
time_shift = nn.ZeroPad2d((0, 0, 1, -1))  # pad/crop along the sequence dim, matching the model's time_shift
xx = time_shift(x1)


x1 = torch.tensor([[1, 2]], dtype=float)
x2 = torch.tensor([[5, 6], [7, 8]], dtype=float)