########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

import torch, types, os, gc, math, json
import numpy as np
import torch.nn as nn
from torch.nn import Module
from torch.nn import functional as F

np.set_printoptions(precision=4, suppress=True, linewidth=200)
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
# torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
# torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = True
torch._C._jit_set_autocast_mode(False)

"""
This will load RWKV-7 "Goose" x070 and inference in GPT-mode (slower than RNN-mode for autoregressive generation)
"""

args = types.SimpleNamespace()

# model download: https://huggingface.co/BlinkDL/rwkv-7-world
MODEL_PATH = "/home/colin/.cache/modelscope/hub/Blink_DL/rwkv-7-world/RWKV-x070-World-0.1B-v2.8-20241210-ctx4096.pth"

# for 0.1B
args.n_layer = 12
args.n_embd = 768
D_DECAY_LORA = 64
D_AAA_LORA = 64
D_MV_LORA = 32
D_GATE_LORA = 128

args.vocab_size = 65536

# DTYPE = torch.bfloat16
DTYPE = torch.half  # better

args.head_size_a = 64  # don't change
HS = args.head_size_a

########################################################################################################
# RWKV Tokenizer (slow version)
########################################################################################################

class RWKV_TOKENIZER:
    table: list[list[list[bytes]]]
    good: list[set[int]]
    wlen: list[int]

    def __init__(self, file_name):
        self.idx2token = {}
        sorted = []  # must be already sorted
        lines = open(file_name, "r", encoding="utf-8").readlines()
        for l in lines:
            idx = int(l[: l.index(" ")])
            x = eval(l[l.index(" ") : l.rindex(" ")])
            x = x.encode("utf-8") if isinstance(x, str) else x
            assert isinstance(x, bytes)
            assert len(x) == int(l[l.rindex(" ") :])
            sorted += [x]
            self.idx2token[idx] = x

        self.token2idx = {}
        for k, v in self.idx2token.items():
            self.token2idx[v] = int(k)

        # precompute some tables for fast matching
        self.table = [[[] for j in range(256)] for i in range(256)]
        self.good = [set() for i in range(256)]
        self.wlen = [0 for i in range(256)]

        for i in reversed(range(len(sorted))):  # reverse order - match longer tokens first
            s = sorted[i]
            if len(s) >= 2:
                s0 = int(s[0])
                s1 = int(s[1])
                self.table[s0][s1] += [s]
                self.wlen[s0] = max(self.wlen[s0], len(s))
                self.good[s0].add(s1)

    def encodeBytes(self, src: bytes) -> list[int]:
        src_len: int = len(src)
        tokens: list[int] = []
        i: int = 0
        while i < src_len:
            s: bytes = src[i : i + 1]

            if i < src_len - 1:
                s1: int = int(src[i + 1])
                s0: int = int(src[i])
                if s1 in self.good[s0]:
                    sss: bytes = src[i : i + self.wlen[s0]]
                    try:
                        s = next(filter(sss.startswith, self.table[s0][s1]))
                    except:
                        pass
            tokens.append(self.token2idx[s])
            i += len(s)

        return tokens

    def decodeBytes(self, tokens):
        return b"".join(map(lambda i: self.idx2token[i], tokens))

    def encode(self, src: str):
        return self.encodeBytes(src.encode("utf-8"))

    def decode(self, tokens):
        return self.decodeBytes(tokens).decode("utf-8")

    def printTokens(self, tokens):
        for i in tokens:
            s = self.idx2token[i]
            try:
                s = s.decode("utf-8")
            except:
                pass
            print(f"{repr(s)}{i}", end=" ")
            # print(repr(s), i)
        print()

tokenizer = RWKV_TOKENIZER("rwkv_vocab_v20230424.txt")
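# Optional sanity check (not part of the original script): the greedy longest-match encoder above
# should make encode() and decode() exact inverses for any UTF-8 string covered by the vocabulary.
# _demo_text is an arbitrary illustrative string.
_demo_text = "Hello, 世界"
assert tokenizer.decode(tokenizer.encode(_demo_text)) == _demo_text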
########################################################################################################
# RWKV TimeMix
########################################################################################################

class RWKV_Tmix_x070(Module):
    def __init__(self, args, layer_id):
        super().__init__()
        self.args = args
        self.layer_id = layer_id

        self.head_size = args.head_size_a
        self.n_head = args.dim_att // self.head_size
        assert args.dim_att % self.n_head == 0

        H = self.n_head
        HS = self.head_size
        C = args.n_embd

        self.x_r = nn.Parameter(torch.empty(1, 1, C))
        self.x_w = nn.Parameter(torch.empty(1, 1, C))
        self.x_k = nn.Parameter(torch.empty(1, 1, C))
        self.x_v = nn.Parameter(torch.empty(1, 1, C))
        self.x_a = nn.Parameter(torch.empty(1, 1, C))
        self.x_g = nn.Parameter(torch.empty(1, 1, C))

        self.w0 = nn.Parameter(torch.empty(1, 1, C))
        self.w1 = nn.Parameter(torch.empty(C, D_DECAY_LORA))
        self.w2 = nn.Parameter(torch.empty(D_DECAY_LORA, C))

        self.a0 = nn.Parameter(torch.empty(1, 1, C))
        self.a1 = nn.Parameter(torch.empty(C, D_AAA_LORA))
        self.a2 = nn.Parameter(torch.empty(D_AAA_LORA, C))

        self.v0 = nn.Parameter(torch.empty(1, 1, C))
        self.v1 = nn.Parameter(torch.empty(C, D_MV_LORA))
        self.v2 = nn.Parameter(torch.empty(D_MV_LORA, C))

        self.g1 = nn.Parameter(torch.empty(C, D_GATE_LORA))
        self.g2 = nn.Parameter(torch.empty(D_GATE_LORA, C))

        self.k_k = nn.Parameter(torch.empty(1, 1, C))
        self.k_a = nn.Parameter(torch.empty(1, 1, C))
        self.r_k = nn.Parameter(torch.empty(H, HS))

        self.receptance = nn.Linear(C, C, bias=False)
        self.key = nn.Linear(C, C, bias=False)
        self.value = nn.Linear(C, C, bias=False)
        self.output = nn.Linear(C, C, bias=False)
        self.ln_x = nn.GroupNorm(H, C, eps=64e-5)  # !!! notice eps value !!!

    def forward(self, x, v_first):
        B, T, C = x.size()  # T = seq_len
        H = self.n_head  # 12

        xx = torch.zeros_like(x)  # time_shift [1, seq_len, 768] -> [1, seq_len, 768]
        xx[:, 0, :] = -x[:, 0, :]
        xx[:, 1:, :] = x[:, :-1, :] - x[:, 1:, :]

        xr = x + xx * self.x_r  # [1, seq_len, 768] * [1, 1, 768] -> [1, seq_len, 768]
        xw = x + xx * self.x_w  # [1, seq_len, 768] * [1, 1, 768] -> [1, seq_len, 768]
        xk = x + xx * self.x_k  # [1, seq_len, 768] * [1, 1, 768] -> [1, seq_len, 768]
        xv = x + xx * self.x_v  # [1, seq_len, 768] * [1, 1, 768] -> [1, seq_len, 768]
        xa = x + xx * self.x_a  # [1, seq_len, 768] * [1, 1, 768] -> [1, seq_len, 768]
        xg = x + xx * self.x_g  # [1, seq_len, 768] * [1, 1, 768] -> [1, seq_len, 768]

        r = self.receptance(xr)  # Linear [1, seq_len, 768] -> [1, seq_len, 768]

        xw = torch.tanh(xw @ self.w1)  # -> [1, seq_len, 64]
        xw = xw @ self.w2 + self.w0  # -> [1, seq_len, 768]
        xw = -F.softplus(-xw)  # softplus outputs (0, inf), so -softplus(-xw) lies in (-inf, 0)
        w = xw - 0.5  # soft-clamped decay in (-inf, -0.5) -> [1, seq_len, 768]

        k = self.key(xk)  # Linear [1, seq_len, 768] -> [1, seq_len, 768]
        v = self.value(xv)  # Linear [1, seq_len, 768] -> [1, seq_len, 768]
        if self.layer_id == 0:
            v_first = v  # store the v of the first layer
        else:
            xv = (xv @ self.v1) @ self.v2
            xv = xv + self.v0
            xv = torch.sigmoid(xv)
            v = v + (v_first - v) * xv  # add value residual -> [1, seq_len, 768]

        xa = (xa @ self.a1) @ self.a2
        xa = xa + self.a0
        a = torch.sigmoid(xa)  # -> [1, seq_len, 768]

        xg = xg @ self.g1
        xg = torch.sigmoid(xg)
        g = xg @ self.g2  # -> [1, seq_len, 768]

        kk = k * self.k_k  # [1, seq_len, 768] * [1, 1, 768] -> [1, seq_len, 768]
        kk = F.normalize(kk.view(B, T, H, -1), dim=-1, p=2.0).view(B, T, C)  # -> [1, seq_len, 768]
        k = k * (1 + (a - 1) * self.k_a)  # -> [1, seq_len, 768]

        # start op
        a_op = -kk
        b_op = kk * a

        B, T, C = r.size()  # C = 768
        H = C // HS  # 12

        r_op = r.view(B, T, H, HS, 1).float()  # -> [1, seq_len, 12, 64, 1]
        k_op = k.view(B, T, H, 1, HS).float()  # -> [1, seq_len, 12, 1, 64]
        v_op = v.view(B, T, H, HS, 1).float()  # -> [1, seq_len, 12, 64, 1]
        a_op = a_op.view(B, T, H, HS, 1).float()  # -> [1, seq_len, 12, 64, 1]
        b_op = b_op.view(B, T, H, 1, HS).float()  # -> [1, seq_len, 12, 1, 64]
        w_op = w.view(B, T, H, HS).float()
        w_op = torch.exp(-torch.exp(w_op))  # -> [1, seq_len, 12, 64]
        w_op = w_op.view(B, T, H, 1, HS)  # -> [1, seq_len, 12, 1, 64]

        out = torch.zeros((B, T, H, HS), device=r_op.device, dtype=torch.float)  # -> [1, seq_len, 12, 64]
        state = torch.zeros((B, H, HS, HS), device=r_op.device, dtype=torch.float)  # -> [1, 12, 64, 64]

        for t in range(T):
            rr_op = r_op[:, t, :]  # [1, seq_len, 12, 64, 1] -> [1, 12, 64, 1]
            kk_op = k_op[:, t, :]  # [1, seq_len, 12, 1, 64] -> [1, 12, 1, 64]
            vv_op = v_op[:, t, :]  # [1, seq_len, 12, 64, 1] -> [1, 12, 64, 1]
            aa_op = a_op[:, t, :]  # [1, seq_len, 12, 64, 1] -> [1, 12, 64, 1]
            bb_op = b_op[:, t, :]  # [1, seq_len, 12, 1, 64] -> [1, 12, 1, 64]
            ww_op = w_op[:, t, :]  # [1, seq_len, 12, 1, 64] -> [1, 12, 1, 64]
            # per-head state update: decay, in-context "erase/replace" via a,b, then write v k^T
            state = state * ww_op + state @ aa_op @ bb_op + vv_op @ kk_op  # -> [1, 12, 64, 64]
            out[:, t, :] = (state @ rr_op).view(B, H, HS)  # read out with r -> [1, 12, 64]

        x = out.view(B, T, C).to(dtype=DTYPE)  # -> [1, seq_len, 768]
        # end op

        x = self.ln_x(x.view(B * T, C)).view(B, T, C)  # -> [1, seq_len, 768]

        xx = r.view(B, T, H, -1) * k.view(B, T, H, -1)  # -> [1, seq_len, 12, 64]
        xx = xx * self.r_k  # -> [1, seq_len, 12, 64]
        xx = xx.sum(dim=-1, keepdim=True)  # -> [1, seq_len, 12, 1]
        xx = xx * v.view(B, T, H, -1)  # [1, seq_len, 12, 1] * [1, seq_len, 12, 64] -> [1, seq_len, 12, 64]
        xx = xx.view(B, T, C)  # -> [1, seq_len, 768]

        x = x + xx  # -> [1, seq_len, 768]
        x = self.output(x * g)  # Linear -> [1, seq_len, 768]
        return x, v_first

########################################################################################################
# RWKV ChannelMix
########################################################################################################

class RWKV_CMix_x070(Module):
    def __init__(self, args, layer_id):
        super().__init__()
        self.args = args
        self.layer_id = layer_id

        with torch.no_grad():
            self.x_k = nn.Parameter(torch.empty(1, 1, args.n_embd))

        self.key = nn.Linear(args.n_embd, args.dim_ffn, bias=False)
        self.value = nn.Linear(args.dim_ffn, args.n_embd, bias=False)

    def forward(self, x):
        shift = torch.zeros_like(x)
        shift[:, 1:, :] = x[:, :-1, :]
        xx = shift - x  # time_shift -> [1, seq_len, 768]

        k = x + xx * self.x_k  # -> [1, seq_len, 768]
        k = torch.relu(self.key(k)) ** 2  # Linear -> [1, seq_len, 3072]
        return self.value(k)  # Linear -> [1, seq_len, 768]

########################################################################################################
# RWKV Block
########################################################################################################

class Block(Module):
    def __init__(self, args, layer_id):
        super().__init__()
        self.args = args
        self.layer_id = layer_id

        self.ln0 = nn.LayerNorm(args.n_embd)  # only used in block 0, should be fused with emb
        self.ln1 = nn.LayerNorm(args.n_embd)
        self.ln2 = nn.LayerNorm(args.n_embd)

        self.att = RWKV_Tmix_x070(args, layer_id)
        self.ffn = RWKV_CMix_x070(args, layer_id)

    def forward(self, x, v_first):
        if self.layer_id == 0:
            x = self.ln0(x)  # LayerNorm -> [1, seq_len, 768], normalize over dim 768, then scale by γ and shift by β

        ln = self.ln1(x)  # LayerNorm -> [1, seq_len, 768], normalize over dim 768, then scale by γ and shift by β
        xx, v_first = self.att(ln, v_first)  # [1, seq_len, 768] -> [1, seq_len, 768]
        x = x + xx  # [1, seq_len, 768] -> [1, seq_len, 768]
        x = x + self.ffn(self.ln2(x))  # [1, seq_len, 768] -> [1, seq_len, 768]
        return x, v_first

########################################################################################################
# RWKV Model
########################################################################################################

class RWKV(nn.Module):
    def __init__(self, args):
        super().__init__()
        args.dim_att = args.n_embd
        args.dim_ffn = args.n_embd * 4

        self.emb = nn.Embedding(args.vocab_size, args.n_embd)
        self.blocks = nn.ModuleList([Block(args, i) for i in range(args.n_layer)])
        self.ln_out = nn.LayerNorm(args.n_embd)
        self.head = nn.Linear(args.n_embd, args.vocab_size, bias=False)

    def forward(self, idx):
        x = self.emb(idx)  # [1, seq_len] -> [1, seq_len, 768]
        v_first = torch.empty_like(x)  # -> [1, seq_len, 768]

        for block in self.blocks:
            x, v_first = block(x, v_first)

        x = self.ln_out(x)  # [1, seq_len, 768] -> [1, seq_len, 768]
        x = self.head(x)  # [1, seq_len, 768] -> [1, seq_len, 65536]
        return x

########################################################################################################
# RWKV Inference
########################################################################################################

model_params = torch.load(MODEL_PATH, map_location="cpu")

with torch.no_grad():
    model = RWKV(args).to(dtype=DTYPE).cuda()
    model.load_state_dict(model_params, strict=False)  # we will ignore blocks.0.att.v0/v1/v2

    ########################################################################################################

    prompt = "中国的首都是在"
    input = tokenizer.encode(prompt)
    print(f"\nInput:\n{input}")

    # expected top-10 next tokens for the prompt (≈ "The capital of China is in"):
    # 中国的首都是在
    # 北 [probability 4.99%]
    # 中 [probability 4.22%]
    # 这 [probability 3.38%]
    # 上 [probability 2.74%]
    # 东 [probability 2.28%]
    # 台 [probability 2.23%]
    # 南 [probability 1.86%]
    # 广 [probability 1.83%]
    # 华 [probability 1.63%]
    # 河 [probability 1.47%]
    out = model.forward(torch.tensor(input).reshape(1, -1).cuda())
    print(f"\nOutput:\n{out}")

    out = out[0, -1]  # logits of the last token => prediction for the next token
    probs = F.softmax(out.float(), dim=-1)  # compute softmax in float (more accurate)

    print(f"\n{prompt}")

    _, indices = torch.topk(probs, 10)  # print top-10 possibilities
    for i in range(len(indices)):
        token_id = indices[i].item()
        token = tokenizer.decode([token_id])
        token_prob = probs[token_id].item()
        print(token, f"[probability {token_prob:.2%}]")

    ########################################################################################################

    with open(f"misc/lambada_test.jsonl", "r", encoding="utf-8") as f:
        todo = [json.loads(line) for line in f]
        todo = [[doc["text"].rsplit(" ", 1)[0], " " + doc["text"].rsplit(" ", 1)[1]] for doc in todo]

    print("\nCheck LAMBADA...")
    xsum = 0
    xcnt = 0
    xacc = 0
    for d in todo:
        src = [0] + tokenizer.encode(d[0])
        dst = tokenizer.encode(d[1])

        logits = 0
        correct = True
        out = model.forward(torch.tensor(src + dst).reshape(1, -1).cuda())
        for i in range(len(dst)):
            ooo = out[0, len(src) - 1 + i].float()
            probs = F.softmax(ooo, dim=-1)
            logits += math.log(probs[dst[i]])
            if torch.argmax(probs).item() != dst[i]:
                correct = False

        xcnt += 1
        xsum += logits
        xacc += 1 if correct else 0
        if xcnt % 100 == 0 or xcnt == len(todo):
            print(xcnt, "ppl", round(math.exp(-xsum / xcnt), 2), "acc", round(xacc / xcnt * 100, 2))
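########################################################################################################

# Illustrative sketch (not part of the original script): greedy continuation in GPT-mode.
# Each step re-runs the full prefix through model.forward, which is why the docstring notes that
# RNN-mode is faster for autoregressive generation. The prompt and the 16-token budget are arbitrary.
with torch.no_grad():
    ctx = tokenizer.encode("中国的首都是在")
    for _ in range(16):
        logits = model.forward(torch.tensor(ctx).reshape(1, -1).cuda())
        ctx.append(int(torch.argmax(logits[0, -1]).item()))  # pick the most likely next token
    tokenizer.printTokens(ctx)  # printTokens tolerates tokens that are not valid UTF-8 on their own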