Compare commits


3 Commits

Author SHA1 Message Date
Colin 2e075552db Add hook of attention for query qkv. 2025-03-03 15:23:39 +08:00
Colin 3eea09d78c Add rwkv v7 demo. 2025-03-03 14:53:15 +08:00
Colin e3b63f4635 Refine model define. 2025-02-28 13:38:28 +08:00
15 changed files with 73666 additions and 333 deletions

rwkv/RWKV-v7/cuda/wkv7.cu (new file, 55 lines)

@@ -0,0 +1,55 @@
#include <stdio.h>
#include <assert.h>
#include "ATen/ATen.h"
typedef at::Half bf16;
// typedef at::BFloat16 bf16;
template <typename F>
__global__ void kernel_forward(const int B, const int T, const int C, const int H,
const F *__restrict__ const _r, const F *__restrict__ const _w, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ const _a, const F *__restrict__ const _b,
F *__restrict__ const _y)
{
const int e = blockIdx.x / H;
const int h = blockIdx.x % H;
const int i = threadIdx.x;
float state[_N_] = {0};
__shared__ float r[_N_], k[_N_], w[_N_], a[_N_], b[_N_];
for (int _t = 0; _t < T; _t++)
{
const int t = e*T*C + h*_N_ + i + _t * C;
__syncthreads();
r[i] = float(_r[t]);
w[i] = __expf(-__expf(float(_w[t])));
k[i] = float(_k[t]);
a[i] = float(_a[t]);
b[i] = float(_b[t]);
__syncthreads();
float sa = 0;
#pragma unroll
for (int j = 0; j < _N_; j++)
{
sa += a[j] * state[j];
}
float vv = float(_v[t]);
float y = 0;
#pragma unroll
for (int j = 0; j < _N_; j++)
{
float& s = state[j];
s = s * w[j] + k[j] * vv + sa * b[j];
y += s * r[j];
}
_y[t] = F(y);
}
}
void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16* w, bf16 *k, bf16 *v, bf16 *a, bf16 *b, bf16 *y)
{
assert(H*_N_ == C);
kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, r, w, k, v, a, b, y);
}
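
For readers checking the kernel, here is a minimal pure-PyTorch sketch of the same per-head recurrence. It mirrors the non-CUDA fallback in the RWKV demo further down this diff; the (B, T, C) layout, the exp(-exp(w)) decay transform, and head_size = 64 are taken from that demo, while the function name is only illustrative.

import torch

def wkv7_reference(r, w, k, v, a, b, head_size=64):
    # All inputs are (B, T, C) with C = H * head_size; everything is computed in float32, like the kernel.
    B, T, C = r.size()
    H, N = C // head_size, head_size
    r, k, v, a, b = [x.view(B, T, H, N).float() for x in (r, k, v, a, b)]
    w = torch.exp(-torch.exp(w.view(B, T, H, N).float()))  # per-channel decay in (0, 1)
    out = torch.zeros(B, T, H, N)
    state = torch.zeros(B, H, N, N)  # one N x N state matrix per head
    for t in range(T):
        sa = state @ a[:, t].unsqueeze(-1)  # (B, H, N, 1), uses the previous state
        state = (state * w[:, t].unsqueeze(-2)                          # decay each column j by w_j
                 + sa @ b[:, t].unsqueeze(-2)                           # in-context update (S a) b^T
                 + v[:, t].unsqueeze(-1) @ k[:, t].unsqueeze(-2))       # write v k^T
        out[:, t] = (state @ r[:, t].unsqueeze(-1)).squeeze(-1)         # read out with r
    return out.view(B, T, C)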


@@ -0,0 +1,15 @@
#include <torch/extension.h>
#include "ATen/ATen.h"
typedef at::Half bf16;
// typedef at::BFloat16 bf16;
void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *w, bf16 *k, bf16 *v, bf16 *a, bf16 *b, bf16 *y);
void forward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &w, torch::Tensor &k, torch::Tensor &v, torch::Tensor &a, torch::Tensor &b, torch::Tensor &y) {
cuda_forward(B, T, C, H, r.data_ptr<bf16>(), w.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), a.data_ptr<bf16>(), b.data_ptr<bf16>(), y.data_ptr<bf16>());
}
TORCH_LIBRARY(wkv7, m) {
m.def("forward", forward);
}


@@ -0,0 +1,64 @@
#include <stdio.h>
#include <assert.h>
#include "ATen/ATen.h"
typedef at::Half bf16;
// typedef at::BFloat16 bf16;
template <typename F>
__global__ void kernel_forward(const int B, const int T, const int C, const int H,
float *__restrict__ _state, const F *__restrict__ const _r, const F *__restrict__ const _w, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ const _a, const F *__restrict__ const _b,
F *__restrict__ const _y)
{
const int e = blockIdx.x / H;
const int h = blockIdx.x % H;
const int i = threadIdx.x;
_state += h*_N_*_N_ + i*_N_; // wrong if B > 1 !!!
float state[_N_];
#pragma unroll
for (int j = 0; j < _N_; j++)
state[j] = _state[j];
__shared__ float r[_N_], k[_N_], w[_N_], a[_N_], b[_N_];
for (int _t = 0; _t < T; _t++)
{
const int t = e*T*C + h*_N_ + i + _t * C;
__syncthreads();
r[i] = float(_r[t]);
w[i] = __expf(-__expf(float(_w[t])));
k[i] = float(_k[t]);
a[i] = float(_a[t]);
b[i] = float(_b[t]);
__syncthreads();
float sa = 0;
#pragma unroll
for (int j = 0; j < _N_; j++)
{
sa += a[j] * state[j];
}
float vv = float(_v[t]);
float y = 0;
#pragma unroll
for (int j = 0; j < _N_; j++)
{
float& s = state[j];
s = s * w[j] + k[j] * vv + sa * b[j];
y += s * r[j];
}
_y[t] = F(y);
}
#pragma unroll
for (int j = 0; j < _N_; j++)
_state[j] = state[j];
}
void cuda_forward(int B, int T, int C, int H, float *state, bf16 *r, bf16* w, bf16 *k, bf16 *v, bf16 *a, bf16 *b, bf16 *y)
{
assert(H*_N_ == C);
assert(B == 1); // only for B=1
kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, w, k, v, a, b, y);
}


@@ -0,0 +1,15 @@
#include <torch/extension.h>
#include "ATen/ATen.h"
typedef at::Half bf16;
// typedef at::BFloat16 bf16;
void cuda_forward(int B, int T, int C, int H, float *state, bf16 *r, bf16 *w, bf16 *k, bf16 *v, bf16 *a, bf16 *b, bf16 *y);
void forward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &w, torch::Tensor &k, torch::Tensor &v, torch::Tensor &a, torch::Tensor &b, torch::Tensor &y) {
cuda_forward(B, T, C, H, state.data_ptr<float>(), r.data_ptr<bf16>(), w.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), a.data_ptr<bf16>(), b.data_ptr<bf16>(), y.data_ptr<bf16>());
}
TORCH_LIBRARY(wkv7s, m) {
m.def("forward", forward);
}
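
Unlike wkv7 above, this variant reads and writes a persistent float32 state (one N x N matrix per head, indexed as h*_N_*_N_ + i*_N_ in the kernel), which is what enables step-by-step RNN-mode decoding, and its launcher asserts B == 1. A rough sketch of driving it from Python, assuming the extension is built under the name wkv7s in the same way the demo below builds wkv7 (the module name, file paths, and compile flags here are assumptions, not taken from this diff):

import torch
from torch.utils.cpp_extension import load

HEAD_SIZE = 64  # must match -D_N_ at compile time

# Assumed build invocation, mirroring the wkv7 build in the demo below.
load(name="wkv7s", sources=["cuda/wkv7s_op.cpp", "cuda/wkv7s.cu"], is_python_module=False,
     verbose=True, extra_cuda_cflags=["-O3", f"-D_N_={HEAD_SIZE}"])

def wkv7s_step(state, r, w, k, v, a, b):
    # state: (C // HEAD_SIZE, HEAD_SIZE, HEAD_SIZE) float32 CUDA tensor, carried across calls and
    # updated in place by the kernel; r/w/k/v/a/b: (1, T, C) contiguous half tensors.
    B, T, C = r.size()
    H = C // HEAD_SIZE
    assert B == 1  # the launcher only supports batch size 1
    y = torch.empty_like(r)
    torch.ops.wkv7s.forward(B, T, C, H, state, r, w, k, v, a, b, y)
    return y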

File diff suppressed because it is too large.


@@ -0,0 +1,436 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import torch, types, os, gc, math, json
import numpy as np
import torch.nn as nn
from torch.nn import functional as F
np.set_printoptions(precision=4, suppress=True, linewidth=200)
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
# torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
# torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = True
torch._C._jit_set_autocast_mode(False)
'''
This will load RWKV-7 "Goose" x070 and inference in GPT-mode (slower than RNN-mode for autoregressive generation)
'''
args = types.SimpleNamespace()
# model download: https://huggingface.co/BlinkDL/rwkv-7-world
MODEL_PATH = "/home/colin/.cache/modelscope/hub/Blink_DL/rwkv-7-world/RWKV-x070-World-0.1B-v2.8-20241210-ctx4096.pth"
# for 0.1B
args.n_layer = 12
args.n_embd = 768
D_DECAY_LORA = 64
D_AAA_LORA = 64
D_MV_LORA = 32
D_GATE_LORA = 128
args.vocab_size = 65536
# DTYPE = torch.bfloat16
DTYPE = torch.half # better
args.head_size_a = 64 # don't change
HEAD_SIZE = args.head_size_a
USE_CUDA_KERNEL = True # False => UNOPTIMIZED, VERY SLOW
MyModule = torch.jit.ScriptModule
MyFunction = torch.jit.script_method
MyStatic = torch.jit.script
########################################################################################################
# RWKV Tokenizer (slow version)
########################################################################################################
class RWKV_TOKENIZER():
table: list[list[list[bytes]]]
good: list[set[int]]
wlen: list[int]
def __init__(self, file_name):
self.idx2token = {}
sorted = [] # must be already sorted
lines = open(file_name, "r", encoding="utf-8").readlines()
for l in lines:
idx = int(l[:l.index(' ')])
x = eval(l[l.index(' '):l.rindex(' ')])
x = x.encode("utf-8") if isinstance(x, str) else x
assert isinstance(x, bytes)
assert len(x) == int(l[l.rindex(' '):])
sorted += [x]
self.idx2token[idx] = x
self.token2idx = {}
for k, v in self.idx2token.items():
self.token2idx[v] = int(k)
# precompute some tables for fast matching
self.table = [[[] for j in range(256)] for i in range(256)]
self.good = [set() for i in range(256)]
self.wlen = [0 for i in range(256)]
for i in reversed(range(len(sorted))): # reverse order - match longer tokens first
s = sorted[i]
if len(s) >= 2:
s0 = int(s[0])
s1 = int(s[1])
self.table[s0][s1] += [s]
self.wlen[s0] = max(self.wlen[s0], len(s))
self.good[s0].add(s1)
def encodeBytes(self, src: bytes) -> list[int]:
src_len: int = len(src)
tokens: list[int] = []
i: int = 0
while i < src_len:
s: bytes = src[i : i + 1]
if i < src_len - 1:
s1: int = int(src[i + 1])
s0: int = int(src[i])
if s1 in self.good[s0]:
sss: bytes = src[i : i + self.wlen[s0]]
try:
s = next(filter(sss.startswith, self.table[s0][s1]))
except:
pass
tokens.append(self.token2idx[s])
i += len(s)
return tokens
def decodeBytes(self, tokens):
return b''.join(map(lambda i: self.idx2token[i], tokens))
def encode(self, src: str):
return self.encodeBytes(src.encode("utf-8"))
def decode(self, tokens):
return self.decodeBytes(tokens).decode('utf-8')
def printTokens(self, tokens):
for i in tokens:
s = self.idx2token[i]
try:
s = s.decode('utf-8')
except:
pass
print(f'{repr(s)}{i}', end=' ')
# print(repr(s), i)
print()
tokenizer = RWKV_TOKENIZER("rwkv_vocab_v20230424.txt")
########################################################################################################
# CUDA Kernel
########################################################################################################
if USE_CUDA_KERNEL:
from torch.utils.cpp_extension import load
load(name="wkv7", sources=["cuda/wkv7_op.cpp", f"cuda/wkv7.cu"], is_python_module=False,
verbose=True, extra_cuda_cflags=["-res-usage", "--use_fast_math", "-O3", "-Xptxas -O3", "--extra-device-vectorization", f"-D_N_={HEAD_SIZE}"])
class WKV_7(torch.autograd.Function):
@staticmethod
def forward(ctx, r, w, k, v, a, b):
with torch.no_grad():
B, T, C = r.size()
H = C // HEAD_SIZE
N = HEAD_SIZE
assert HEAD_SIZE == C // H
assert r.dtype == DTYPE
assert w.dtype == DTYPE
assert k.dtype == DTYPE
assert v.dtype == DTYPE
assert a.dtype == DTYPE
assert b.dtype == DTYPE
assert r.is_contiguous()
assert w.is_contiguous()
assert k.is_contiguous()
assert v.is_contiguous()
assert a.is_contiguous()
assert b.is_contiguous()
y = torch.empty((B, T, C), device=k.device, dtype=DTYPE, memory_format=torch.contiguous_format)
torch.ops.wkv7.forward(B, T, C, H, r, w, k, v, a, b, y)
return y
def RWKV7_OP(r, w, k, v, a, b):
return WKV_7.apply(r, w, k, v, a, b)
else:
def RWKV7_OP(r, w, k, v, a, b):
B, T, C = r.size()
H = C // HEAD_SIZE
N = HEAD_SIZE
r = r.view(B, T, H, N).float()
k = k.view(B, T, H, N).float()
v = v.view(B, T, H, N).float()
a = a.view(B, T, H, N).float()
b = b.view(B, T, H, N).float()
w = torch.exp(-torch.exp(w.view(B, T, H, N).float()))
out = torch.zeros((B, T, H, N), device=r.device, dtype=torch.float)
state = torch.zeros((B, H, N, N), device=r.device, dtype=torch.float)
for t in range(T):
kk = k[:, t, :].view(B, H, 1, N)
rr = r[:, t, :].view(B, H, N, 1)
vv = v[:, t, :].view(B, H, N, 1)
aa = a[:, t, :].view(B, H, N, 1)
bb = b[:, t, :].view(B, H, 1, N)
state = state * w[: , t, :, None, :] + state @ aa @ bb + vv @ kk
out[:, t, :] = (state @ rr).view(B, H, N)
# another method using einsum
#
# kk = k[:, t, :]
# rr = r[:, t, :]
# vv = v[:, t, :]
# aa = a[:, t, :]
# bb = b[:, t, :]
# sab = torch.einsum('bhik,bhk,bhj->bhij', state, aa, bb)
# state = state * w[: , t, :, None, :] + sab + torch.einsum('bhj,bhi->bhij', kk, vv)
# out[:, t, :] = torch.einsum('bhj,bhij->bhi', rr, state)
return out.view(B, T, C).to(dtype=DTYPE)
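# Note on RWKV7_OP: per head, with S the N x N state (N = HEAD_SIZE),
#   S_t = S_{t-1} @ diag(w_t) + (S_{t-1} @ a_t) @ b_t^T + v_t @ k_t^T
#   y_t = S_t @ r_t
# This is the recurrence computed both by the CUDA kernel and by the Python fallback above.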
########################################################################################################
# RWKV TimeMix
########################################################################################################
class RWKV_Tmix_x070(MyModule):
def __init__(self, args, layer_id):
super().__init__()
self.args = args
self.layer_id = layer_id
self.head_size = args.head_size_a
self.n_head = args.dim_att // self.head_size
assert args.dim_att % self.n_head == 0
H = self.n_head
N = self.head_size
C = args.n_embd
self.x_r = nn.Parameter(torch.empty(1,1,C))
self.x_w = nn.Parameter(torch.empty(1,1,C))
self.x_k = nn.Parameter(torch.empty(1,1,C))
self.x_v = nn.Parameter(torch.empty(1,1,C))
self.x_a = nn.Parameter(torch.empty(1,1,C))
self.x_g = nn.Parameter(torch.empty(1,1,C))
self.w0 = nn.Parameter(torch.empty(1,1,C))
self.w1 = nn.Parameter(torch.empty(C, D_DECAY_LORA))
self.w2 = nn.Parameter(torch.empty(D_DECAY_LORA, C))
self.a0 = nn.Parameter(torch.empty(1,1,C))
self.a1 = nn.Parameter(torch.empty(C, D_AAA_LORA))
self.a2 = nn.Parameter(torch.empty(D_AAA_LORA, C))
self.v0 = nn.Parameter(torch.empty(1,1,C))
self.v1 = nn.Parameter(torch.empty(C, D_MV_LORA))
self.v2 = nn.Parameter(torch.empty(D_MV_LORA, C))
self.g1 = nn.Parameter(torch.empty(C, D_GATE_LORA))
self.g2 = nn.Parameter(torch.empty(D_GATE_LORA, C))
self.k_k = nn.Parameter(torch.empty(1,1,C))
self.k_a = nn.Parameter(torch.empty(1,1,C))
self.r_k = nn.Parameter(torch.empty(H,N))
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
self.receptance = nn.Linear(C, C, bias=False)
self.key = nn.Linear(C, C, bias=False)
self.value = nn.Linear(C, C, bias=False)
self.output = nn.Linear(C, C, bias=False)
self.ln_x = nn.GroupNorm(H, C, eps=64e-5) # !!! notice eps value !!!
@MyFunction
def forward(self, x, v_first):
B, T, C = x.size()
H = self.n_head
xx = self.time_shift(x) - x
xr = x + xx * self.x_r
xw = x + xx * self.x_w
xk = x + xx * self.x_k
xv = x + xx * self.x_v
xa = x + xx * self.x_a
xg = x + xx * self.x_g
r = self.receptance(xr)
w = -F.softplus(-(self.w0 + torch.tanh(xw @ self.w1) @ self.w2)) - 0.5 # soft-clamp to (-inf, -0.5)
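        # Since w <= -0.5 here, the per-step decay exp(-exp(w)) used by RWKV7_OP stays in about (0.54, 1).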
k = self.key(xk)
v = self.value(xv)
if self.layer_id == 0:
v_first = v # store the v of the first layer
else:
v = v + (v_first - v) * torch.sigmoid(self.v0 + (xv @ self.v1) @ self.v2) # add value residual
a = torch.sigmoid(self.a0 + (xa @ self.a1) @ self.a2) # a is "in-context learning rate"
g = torch.sigmoid(xg @ self.g1) @ self.g2
kk = k * self.k_k
kk = F.normalize(kk.view(B,T,H,-1), dim=-1, p=2.0).view(B,T,C)
k = k * (1 + (a-1) * self.k_a)
x = RWKV7_OP(r, w, k, v, -kk, kk*a)
x = self.ln_x(x.view(B * T, C)).view(B, T, C)
x = x + ((r.view(B,T,H,-1)*k.view(B,T,H,-1)*self.r_k).sum(dim=-1, keepdim=True) * v.view(B,T,H,-1)).view(B,T,C)
x = self.output(x * g)
return x, v_first
########################################################################################################
# RWKV ChannelMix
########################################################################################################
class RWKV_CMix_x070(MyModule):
def __init__(self, args, layer_id):
super().__init__()
self.args = args
self.layer_id = layer_id
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
with torch.no_grad():
self.x_k = nn.Parameter(torch.empty(1, 1, args.n_embd))
self.key = nn.Linear(args.n_embd, args.dim_ffn, bias=False)
self.value = nn.Linear(args.dim_ffn, args.n_embd, bias=False)
@MyFunction
def forward(self, x):
xx = self.time_shift(x) - x
k = x + xx * self.x_k
k = torch.relu(self.key(k)) ** 2
return self.value(k)
########################################################################################################
# RWKV Block
########################################################################################################
class Block(MyModule):
def __init__(self, args, layer_id):
super().__init__()
self.args = args
self.layer_id = layer_id
self.ln0 = nn.LayerNorm(args.n_embd) # only used in block 0, should be fused with emb
self.ln1 = nn.LayerNorm(args.n_embd)
self.ln2 = nn.LayerNorm(args.n_embd)
self.att = RWKV_Tmix_x070(args, layer_id)
self.ffn = RWKV_CMix_x070(args, layer_id)
@MyFunction
def forward(self, x, v_first):
if self.layer_id == 0:
x = self.ln0(x)
xx, v_first = self.att(self.ln1(x), v_first)
x = x + xx
x = x + self.ffn(self.ln2(x))
return x, v_first
########################################################################################################
# RWKV Model
########################################################################################################
class RWKV(nn.Module):
def __init__(self, args):
super().__init__()
args.dim_att = args.n_embd
args.dim_ffn = args.n_embd * 4
self.emb = nn.Embedding(args.vocab_size, args.n_embd)
self.blocks = nn.ModuleList([Block(args, i) for i in range(args.n_layer)])
self.ln_out = nn.LayerNorm(args.n_embd)
self.head = nn.Linear(args.n_embd, args.vocab_size, bias=False)
def forward(self, idx):
x = self.emb(idx)
v_first = torch.empty_like(x)
for block in self.blocks:
x, v_first = block(x, v_first)
x = self.ln_out(x)
x = self.head(x)
return x
########################################################################################################
# RWKV Inference
########################################################################################################
model_params = torch.load(MODEL_PATH, map_location="cpu")
with torch.no_grad():
model = RWKV(args).to(dtype=DTYPE).cuda()
model.load_state_dict(model_params, strict=False) # we will ignore blocks.0.att.v0/v1/v2
########################################################################################################
prompt = "中国的首都是在"
input = tokenizer.encode(prompt)
print(f'\nInput:\n{input}')
out = model.forward(torch.tensor(input).reshape(1,-1).cuda())
print(f'\nOutput:\n{out}')
# logits of the last token => prediction for the next token
out = out[0, -1]
probs = F.softmax(out.float(), dim=-1) # compute softmax in float (more accurate)
print(f'\n{prompt}')
_, indices = torch.topk(probs, 10) # print top-10 possibilities
for i in range(len(indices)):
token_id = indices[i].item()
token = tokenizer.decode([token_id])
token_prob = probs[token_id].item()
print(token, f'[probability {token_prob:.2%}]')
########################################################################################################
with open(f"misc/lambada_test.jsonl", "r", encoding="utf-8") as f:
todo = [json.loads(line) for line in f]
todo = [[doc['text'].rsplit(' ', 1)[0], " " + doc['text'].rsplit(' ', 1)[1]] for doc in todo]
print('\nCheck LAMBADA...')
xsum = 0
xcnt = 0
xacc = 0
for d in todo:
src = [0] + tokenizer.encode(d[0])
dst = tokenizer.encode(d[1])
logits = 0
correct = True
out = model.forward(torch.tensor(src+dst).reshape(1,-1).cuda())
for i in range(len(dst)):
ooo = out[0,len(src)-1+i].float()
probs = F.softmax(ooo, dim=-1)
logits += math.log(probs[dst[i]])
if torch.argmax(probs).item() != dst[i]:
correct = False
xcnt += 1
xsum += logits
xacc += 1 if correct else 0
if xcnt % 100 == 0 or xcnt == len(todo):
print(xcnt, 'ppl', round(math.exp(-xsum / xcnt), 2), 'acc', round(xacc/xcnt*100, 2))
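
Beyond the top-10 dump and the LAMBADA check above, the same GPT-mode forward supports simple greedy generation by re-running the growing prefix at each step (slow, as the header comment warns, but it needs nothing beyond the objects already defined in this file). A minimal sketch; the function name is illustrative:

def greedy_generate(prompt: str, max_new_tokens: int = 32) -> str:
    tokens = tokenizer.encode(prompt)
    with torch.no_grad():
        for _ in range(max_new_tokens):
            logits = model.forward(torch.tensor(tokens).reshape(1, -1).cuda())
            tokens.append(int(torch.argmax(logits[0, -1]).item()))
    return tokenizer.decode(tokens)

# e.g. print(greedy_generate("中国的首都是在"))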

File diff suppressed because it is too large.

wit/90800.ini (new file, 1036 lines): file diff suppressed because it is too large.

wit/Untitled-1.ini (new file, 1036 lines): file diff suppressed because it is too large.


@@ -1,26 +1,24 @@
-import pytorch_lightning as pl
 import torch
 from model.qwen_module import QwenModule
-from model.modeling_wit import QwenRunner
+from model.qwen_module import ModelRunner
-from model.tokenization_qwen import QWenTokenizer
 import numpy as np
-import configuration
 import dataset.dataset as ds
-import dataset.node_tree as nt
 
 if __name__ == "__main__":
     # checkpoint_path = "log/bigger/version_0/checkpoints/epoch=19-step=98720.ckpt"
     checkpoint_path = "log/bigger/version_1/checkpoints/epoch=14-step=74040.ckpt"
+    checkpoint_path = "log/bigger/version_3/checkpoints/epoch=46-step=231992.ckpt"
+    checkpoint_path = "log/bigger/version_8/checkpoints/epoch=49-step=246800.ckpt"
 
     qwen = QwenModule.load_from_checkpoint(checkpoint_path=checkpoint_path)
     qwen.eval()
     conf = qwen.config
 
     torch.manual_seed(conf.seed)
     np.random.seed(conf.seed)
-    runner = QwenRunner(qwen.llm)
+    runner = ModelRunner(qwen.llm)
 
     # batch = torch.tensor([[11, 0, 3, 7, 15, 8, 10, 7, 14, 13, 1, 12, 13]], dtype=torch.int64)
     # sorted_logits, sorted_indices = runner.ChatTokens(batch, sample=False)
@@ -43,4 +41,4 @@ if __name__ == "__main__":
         if item[i] != next_token:
             node.set_seq_prop(i, "ERR_" + str(next_token))
             print(str(item[i]) + " " + str(next_token) + " ERROR")
-    node.print()
+    # node.print()


@@ -1,10 +1,3 @@
-import copy
-import math
-import os
-import sys
-import gc
-from tqdm import auto as tqdm_lib
-import json
 from typing import Optional, Tuple, Union, Callable, List, Any, Generator
 from einops import rearrange
@@ -13,36 +6,22 @@ import torch.nn.functional as F
 import torch.utils.checkpoint
 from torch.nn import CrossEntropyLoss
 from torch import nn
-from safetensors.torch import load_file as safe_load_file
-from safetensors.torch import save_file as safe_save_file
-from model.qwen_generation_utils import (
-    make_context,
-    decode_tokens,
-)
-sys.path.append("..")
-from tools import show
-from tools import mem_tracker
-# tracker = mem_tracker.MemTracker()
-# tracker.track()
 
-class RMSNorm(torch.nn.Module):
+class QWenModel(nn.Module):
+    class RMSNorm(torch.nn.Module):
         def __init__(self, dim: int, eps: float = 1e-6):
             super().__init__()
             self.eps = eps
             self.weight = nn.Parameter(torch.ones(dim))
 
-    def _norm(self, x):
-        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
         def forward(self, x):
-        return self._norm(x.float()).type_as(x) * self.weight
+            norm = x.float() * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + self.eps)
+            return norm.type_as(x) * self.weight
 
+    class Block(nn.Module):
-class QWenAttention(nn.Module):
+        class Attention(nn.Module):
             def __init__(self, config, index):
                 super().__init__()
                 self.hidden_size = config.hidden_size
@@ -64,8 +43,7 @@ class QWenAttention(nn.Module):
                 new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
                 return tensor.view(new_shape)
 
-class QWenMLP(nn.Module):
+        class MLP(nn.Module):
             def __init__(self, config):
                 super().__init__()
                 ff_dim_in = config.intermediate_size // 2
@@ -73,32 +51,28 @@ class QWenMLP(nn.Module):
                 self.w2 = nn.Linear(config.hidden_size, ff_dim_in, bias=not config.no_bias)
                 self.c_proj = nn.Linear(ff_dim_in, config.hidden_size, bias=not config.no_bias)
 
-class QWenBlock(nn.Module):
         def __init__(self, config, index):
             super().__init__()
-        self.ln_1 = RMSNorm(
+            self.ln_1 = QWenModel.RMSNorm(
                 config.hidden_size,
                 eps=config.layer_norm_epsilon,
             )
-        self.attn = QWenAttention(config, index)
+            self.attn = QWenModel.Block.Attention(config, index)
-        self.ln_2 = RMSNorm(
+            self.ln_2 = QWenModel.RMSNorm(
                 config.hidden_size,
                 eps=config.layer_norm_epsilon,
            )
-        self.mlp = QWenMLP(config)
+            self.mlp = QWenModel.Block.MLP(config)
             self.index = index
 
-class QWenModel(nn.Module):
     def __init__(self, config):
         super().__init__()
         self.wte = nn.Embedding(config.vocab_size, config.hidden_size)
         self.drop = nn.Dropout(config.emb_dropout_prob)
         self.dim = config.hidden_size // config.num_attention_heads
-        self.h = nn.ModuleList([QWenBlock(config, i) for i in range(config.num_hidden_layers)])
+        self.h = nn.ModuleList([QWenModel.Block(config, i) for i in range(config.num_hidden_layers)])
-        self.ln_f = RMSNorm(
+        self.ln_f = QWenModel.RMSNorm(
             config.hidden_size,
             eps=config.layer_norm_epsilon,
         )
@@ -133,127 +107,7 @@ class QWenLMHeadModel(nn.Module):
         self.config = config
         self.transformer = QWenModel(config)
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+        self.hook_attention = None
 
-    def forward(
-        self,
-        input_ids: Optional[torch.LongTensor] = None,
-        labels: Optional[torch.LongTensor] = None,
-        token_type_ids: Optional[torch.LongTensor] = None,
-        **kwargs,
-    ):
-        runner = QwenRunner(self)
-        return runner.forwardQWen(input_ids, labels)
-
-    def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]]):
-        pretrained_model_name_or_path = str(pretrained_model_name_or_path)
-        resolved_archive_file = os.path.join(pretrained_model_name_or_path, "model.safetensors.index.json")
-        print(f"loading weights file {resolved_archive_file}")
-        with open(resolved_archive_file, "r") as f:
-            index = json.loads(f.read())
-        shard_filenames = sorted(set(index["weight_map"].values()))
-        resolved_archive_file = [os.path.join(pretrained_model_name_or_path, f) for f in shard_filenames]
-        model = cls._load_pretrained_model(resolved_archive_file)
-        return model
-
-    def _load_state_dict_into_model(self, model_to_load, state_dict, start_prefix):
-        metadata = getattr(state_dict, "_metadata", None)
-        state_dict = state_dict.copy()
-        if metadata is not None:
-            state_dict._metadata = metadata
-        error_msgs = []
-
-        def load(module: nn.Module, state_dict, prefix=""):
-            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
-            args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
-            if len([key for key in state_dict if key.startswith(prefix)]) > 0:
-                module._load_from_state_dict(*args)
-            for name, child in module._modules.items():
-                if child is not None:
-                    load(child, state_dict, prefix + name + ".")
-
-        load(model_to_load, state_dict, prefix=start_prefix)
-        del state_dict
-        return error_msgs
-
-    def _load_pretrained_model(cls, resolved_archive_file):
-        start_prefix = ""
-        model_to_load = cls
-        if len(resolved_archive_file) > 1:
-            resolved_archive_file = tqdm_lib.tqdm(resolved_archive_file, desc="Loading checkpoint shards")
-        for shard_file in resolved_archive_file:
-            state_dict = safe_load_file(shard_file)
-            cls._load_state_dict_into_model(model_to_load, state_dict, start_prefix)
-            del state_dict  # force memory release
-            gc.collect()
-        print(f"All model checkpoint weights were used when initializing {cls.__class__.__name__}.\n")
-        return cls
-
-class QwenRunner:
-    def __init__(self, qwen):
-        self.qwen = qwen
-        # torch.backends.cuda.enable_flash_sdp(True)
-
-    @torch.no_grad()
-    def ChatTokens(self, input_ids, sample=True):
-        qwen = self.qwen
-        input_ids = input_ids.to(next(qwen.parameters()).device)
-        outputs, loss = self.forwardQWen(input_ids)
-        next_token_scores = outputs[:, -1, :]
-        next_token_scores = self.repetition_penalty(input_ids, next_token_scores)
-        if sample:
-            next_token_scores = self.top_p(next_token_scores)
-            return self.sample(next_token_scores)
-        else:
-            return torch.sort(next_token_scores, descending=True)
-
-    @torch.no_grad()
-    def Chat(
-        self,
-        tokenizer,
-        query: str,
-        query_assistant: str,
-        gen_length=0,
-        system: str = "You are a helpful assistant.",
-        history=[],
-    ):
-        qwen = self.qwen
-        history = copy.deepcopy(history)
-        self.qwen.config.pad_token_id = tokenizer.eod_id
-        self.qwen.config.eos_token_id = tokenizer.eod_id
-        raw_text, context_tokens = self.prepareInput(tokenizer, query, query_assistant, history, system)
-        input_ids = torch.tensor([context_tokens]).to(next(qwen.parameters()).device)
-        self.unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
-        input_length = input_ids.shape[1]
-        while True:
-            outputs, loss = self.forwardQWen(input_ids)
-            next_token_scores = outputs[:, -1, :]
-            next_token_scores = self.repetition_penalty(input_ids, next_token_scores)
-            next_token_scores = self.top_p(next_token_scores)
-            next_tokens = self.sample(next_token_scores)
-            finish, next_tokens = self.isFinish(next_tokens)
-            if finish:
-                break
-            input_ids = torch.cat([input_ids, next_tokens], dim=-1)
-            if gen_length != 0 and (input_length + gen_length) < input_ids.shape[1]:
-                break
-        decoded, response, end_reason = decode_tokens(
-            input_ids[0],
-            tokenizer,
-            raw_text_len=len(raw_text),
-            context_length=len(context_tokens),
-            errors="replace",
-        )
-        history.append((query, response))
-        return input_ids[0].cpu().tolist(), history, decoded
-
-    def _rotate_half(self, x):
-        x = rearrange(x, "... (j d) -> ... j d", j=2)
-        x1, x2 = x.unbind(dim=-2)
-        return torch.cat((-x2, x1), dim=-1)
 
     def apply_rotary_pos_emb(self, t, freqs):
         rot_dim = freqs[0].shape[-1]
@@ -261,81 +115,20 @@ class QwenRunner:
         t_float = t.float()
         t_rot = t_float[..., :rot_dim]
         t_pass = t_float[..., rot_dim:]
-        t_rot = (t_rot * cos) + (self._rotate_half(t_rot) * sin)
+        x = rearrange(t_rot, "... (j d) -> ... j d", j=2)
+        x1, x2 = x.unbind(dim=-2)
+        _rotate_half = torch.cat((-x2, x1), dim=-1)
+        t_rot = (t_rot * cos) + (_rotate_half * sin)
         return torch.cat((t_rot, t_pass), dim=-1).type_as(t)
 
-    def split_heads(
-        self,
-        attention,
-        hidden_states: Optional[Tuple[torch.FloatTensor]],
-    ):
-        atten = attention
-        mixed_x_layer = atten.c_attn(hidden_states)
-        query, key, value = mixed_x_layer.split(atten.split_size, dim=2)
-        query = atten._split_heads(query, atten.num_heads, atten.head_dim)
-        key = atten._split_heads(key, atten.num_heads, atten.head_dim)
-        value = atten._split_heads(value, atten.num_heads, atten.head_dim)
-        return query, key, value
-
-    def pos_emb(self, query, key, rotary_pos_emb_list):
-        rotary_pos_emb = rotary_pos_emb_list[0]
-        rotary_pos_emb = [i[:, -query.shape[1] :, :, :] for i in rotary_pos_emb]
-        rotary_pos_emb = (rotary_pos_emb,) * 2
-        query = self.apply_rotary_pos_emb(query, rotary_pos_emb[0])
-        key = self.apply_rotary_pos_emb(key, rotary_pos_emb[1])
-        return query, key
-
-    def attention(self, attention, query, key, value, causal_mask):
-        query = query.permute(0, 2, 1, 3)
-        key = key.permute(0, 2, 1, 3)
-        value = value.permute(0, 2, 1, 3)
-        attn_output = F.scaled_dot_product_attention(query, key, value, attn_mask=causal_mask).transpose(1, 2)
-        context_layer = attention._merge_heads(attn_output, attention.num_heads, attention.head_dim)
-        attn_output = attention.c_proj(context_layer)
-        return attn_output
-
-    def build_mask(self, query):
-        size = query.size(1)
-        causal_mask = torch.tril(torch.ones((size, size), dtype=torch.bool, device=query.device)).view(1, 1, size, size)
-        return causal_mask
-
-    def forwardAttention(
-        self,
-        attention,
-        hidden_states: Optional[Tuple[torch.FloatTensor]],
-        rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None,
-    ):
-        query, key, value = self.split_heads(attention, hidden_states)
-        query, key = self.pos_emb(query, key, rotary_pos_emb_list)
-        causal_mask = self.build_mask(query)
-        return self.attention(attention, query, key, value, causal_mask)
-
-    def forwardQWenBlock(
-        self,
-        block,
-        hidden_states: Optional[Tuple[torch.FloatTensor]],
-        rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None,
-    ):
-        layernorm_output = block.ln_1(hidden_states)
-        attn_outputs = self.forwardAttention(block.attn, layernorm_output, rotary_pos_emb_list)
-        layernorm_input = attn_outputs + hidden_states
-        layernorm_output = block.ln_2(layernorm_input)
-        a1 = block.mlp.w1(layernorm_output)
-        a2 = block.mlp.w2(layernorm_output)
-        intermediate_parallel = a1 * F.silu(a2)
-        mlp_output = block.mlp.c_proj(intermediate_parallel)
-        hidden_states = layernorm_input + mlp_output
-        return hidden_states
-
-    def forwardQWen(
+    def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
         labels: Optional[torch.LongTensor] = None,
     ):
-        transfm = self.qwen.transformer
+        transfm = self.transformer
         input_shape = input_ids.size()
         input_ids = input_ids.view(-1, input_shape[-1])
         hidden_states = transfm.wte(input_ids)
@@ -348,13 +141,52 @@ class QwenRunner:
         hidden_states = transfm.drop(hidden_states)
         output_shape = input_shape + (hidden_states.size(-1),)
 
-        for block in transfm.h:
-            hidden_states = self.forwardQWenBlock(block, hidden_states, rotary_pos_emb_list=rotary_pos_emb_list)
+        for index, block in enumerate(transfm.h):
+            layernorm_output = block.ln_1(hidden_states)
+            # split_heads
+            atten = block.attn
+            mixed_x_layer = atten.c_attn(layernorm_output)
+            query, key, value = mixed_x_layer.split(atten.split_size, dim=2)
+            query = atten._split_heads(query, atten.num_heads, atten.head_dim)
+            key = atten._split_heads(key, atten.num_heads, atten.head_dim)
+            value = atten._split_heads(value, atten.num_heads, atten.head_dim)
+            # pos_emb
+            rotary_pos_emb = rotary_pos_emb_list[0]
+            rotary_pos_emb = [i[:, -query.shape[1] :, :, :] for i in rotary_pos_emb]
+            rotary_pos_emb = (rotary_pos_emb,) * 2
+            query = self.apply_rotary_pos_emb(query, rotary_pos_emb[0])
+            key = self.apply_rotary_pos_emb(key, rotary_pos_emb[1])
+            # build_mask
+            size = query.size(1)
+            causal_mask = torch.tril(torch.ones((size, size), dtype=torch.bool, device=query.device)).view(
+                1, 1, size, size
+            )
+            # attention
+            q = query.permute(0, 2, 1, 3)
+            k = key.permute(0, 2, 1, 3)
+            v = value.permute(0, 2, 1, 3)
+            attn_output = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask).transpose(1, 2)
+            if self.hook_attention:
+                self.hook_attention(query, key, causal_mask, index)
+            context_layer = block.attn._merge_heads(attn_output, block.attn.num_heads, block.attn.head_dim)
+            attn_outputs = block.attn.c_proj(context_layer)
+            layernorm_input = attn_outputs + hidden_states
+            layernorm_output = block.ln_2(layernorm_input)
+            a1 = block.mlp.w1(layernorm_output)
+            a2 = block.mlp.w2(layernorm_output)
+            intermediate_parallel = a1 * F.silu(a2)
+            mlp_output = block.mlp.c_proj(intermediate_parallel)
+            hidden_states = layernorm_input + mlp_output
 
         hidden_states = transfm.ln_f(hidden_states)
         hidden_states = hidden_states.view(output_shape)
-        lm_logits = self.qwen.lm_head(hidden_states)
+        lm_logits = self.lm_head(hidden_states)
 
         loss = None
         if labels is not None:
@@ -362,52 +194,9 @@ class QwenRunner:
             shift_labels = labels[..., 1:].contiguous().view(-1)
             shift_logits = lm_logits[..., :-1, :].contiguous()
             shift_logits = shift_logits.view(-1, shift_logits.size(-1))
-            mask = shift_labels < self.qwen.config.vocab_size
+            mask = shift_labels < self.config.vocab_size
             shift_labels = shift_labels[mask]
             shift_logits = shift_logits[mask]
-            # m = torch.max(shift_logits, 1).indices.cpu().numpy()
-            # ll = shift_labels.cpu().numpy()
             loss = CrossEntropyLoss()(shift_logits, shift_labels)
         return lm_logits, loss
 
-    def prepareInput(self, tokenizer, query, query_assistant, history, system):
-        return make_context(tokenizer, query, query_assistant, history=history, system=system)
-
-    def repetition_penalty(self, input_ids, next_token_scores):
-        penalty = self.qwen.config.repetition_penalty
-        score = torch.gather(next_token_scores, 1, input_ids)
-        # if score < 0 then repetition penalty has to be multiplied to reduce the token probabilities
-        score = torch.where(score < 0, score * penalty, score / penalty)
-        next_token_scores = next_token_scores.scatter_(1, input_ids, score)
-        return next_token_scores
-
-    def top_p(self, next_token_scores):
-        top_p = self.qwen.config.top_p
-        filter_value = -float("Inf")
-        min_tokens_to_keep = 1
-        sorted_logits, sorted_indices = torch.sort(next_token_scores, descending=False)
-        cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
-        # Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
-        sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
-        # Keep at least min_tokens_to_keep
-        sorted_indices_to_remove[..., -min_tokens_to_keep:] = 0
-        # scatter sorted tensors to original indexing
-        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
-        next_token_scores = next_token_scores.masked_fill(indices_to_remove, filter_value)
-        return next_token_scores
-
-    def sample(self, next_token_scores):
-        probs = nn.functional.softmax(next_token_scores, dim=-1)
-        next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
-        return next_tokens
-
-    def isFinish(self, next_tokens):
-        pad_token_id = self.qwen.config.pad_token_id
-        eos_token_id_tensor = torch.tensor([self.qwen.config.eos_token_id]).to(next_tokens.device)
-        next_tokens = next_tokens * self.unfinished_sequences + pad_token_id * (1 - self.unfinished_sequences)
-        self.unfinished_sequences = self.unfinished_sequences.mul(
-            next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
-        )
-        return self.unfinished_sequences.max() == 0, next_tokens[:, None]


@@ -1,5 +1,13 @@
+import os
+import gc
+import json
+from tqdm import auto as tqdm_lib
+from torch import nn
+from safetensors.torch import load_file as safe_load_file
+from safetensors.torch import save_file as safe_save_file
 from functools import cache
-from typing import Dict, Optional
+from typing import Dict, Optional, Union
 
 import pytorch_lightning as pl
 import torch
@@ -9,6 +17,154 @@ from model.modeling_wit import QWenLMHeadModel
 from configuration import ModelConfig, TrainConfig
 
+class LoadModule:
+    def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]]):
+        pretrained_model_name_or_path = str(pretrained_model_name_or_path)
+        resolved_archive_file = os.path.join(pretrained_model_name_or_path, "model.safetensors.index.json")
+        print(f"loading weights file {resolved_archive_file}")
+        with open(resolved_archive_file, "r") as f:
+            index = json.loads(f.read())
+        shard_filenames = sorted(set(index["weight_map"].values()))
+        resolved_archive_file = [os.path.join(pretrained_model_name_or_path, f) for f in shard_filenames]
+        model = LoadModule._load_pretrained_model(cls, resolved_archive_file)
+        return model
+
+    def _load_state_dict_into_model(model_to_load, state_dict, start_prefix):
+        metadata = getattr(state_dict, "_metadata", None)
+        state_dict = state_dict.copy()
+        if metadata is not None:
+            state_dict._metadata = metadata
+        error_msgs = []
+
+        def load(module: nn.Module, state_dict, prefix=""):
+            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
+            args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
+            if len([key for key in state_dict if key.startswith(prefix)]) > 0:
+                module._load_from_state_dict(*args)
+            for name, child in module._modules.items():
+                if child is not None:
+                    load(child, state_dict, prefix + name + ".")
+
+        load(model_to_load, state_dict, prefix=start_prefix)
+        del state_dict
+        return error_msgs
+
+    def _load_pretrained_model(cls, resolved_archive_file):
+        start_prefix = ""
+        model_to_load = cls
+        if len(resolved_archive_file) > 1:
+            resolved_archive_file = tqdm_lib.tqdm(resolved_archive_file, desc="Loading checkpoint shards")
+        for shard_file in resolved_archive_file:
+            state_dict = safe_load_file(shard_file)
+            LoadModule._load_state_dict_into_model(model_to_load, state_dict, start_prefix)
+            del state_dict  # force memory release
+            gc.collect()
+        print(f"All model checkpoint weights were used when initializing {cls.__class__.__name__}.\n")
+        return cls
+
+class ModelRunner:
+    def __init__(self, qwen):
+        self.qwen = qwen
+
+    @torch.no_grad()
+    def ChatTokens(self, input_ids, sample=True):
+        qwen = self.qwen
+        input_ids = input_ids.to(next(qwen.parameters()).device)
+        outputs, loss = qwen.forward(input_ids)
+        next_token_scores = outputs[:, -1, :]
+        next_token_scores = self.repetition_penalty(input_ids, next_token_scores)
+        if sample:
+            next_token_scores = self.top_p(next_token_scores)
+            return self.sample(next_token_scores)
+        else:
+            return torch.sort(next_token_scores, descending=True)
+
+    @torch.no_grad()
+    def Chat(
+        self,
+        tokenizer,
+        query: str,
+        query_assistant: str,
+        gen_length=0,
+        system: str = "You are a helpful assistant.",
+        history=[],
+    ):
+        qwen = self.qwen
+        history = copy.deepcopy(history)
+        self.qwen.config.pad_token_id = tokenizer.eod_id
+        self.qwen.config.eos_token_id = tokenizer.eod_id
+        raw_text, context_tokens = qwen.prepareInput(tokenizer, query, query_assistant, history, system)
+        input_ids = torch.tensor([context_tokens]).to(next(qwen.parameters()).device)
+        self.unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
+        input_length = input_ids.shape[1]
+        while True:
+            outputs, loss = self.forward(input_ids)
+            next_token_scores = outputs[:, -1, :]
+            next_token_scores = self.repetition_penalty(input_ids, next_token_scores)
+            next_token_scores = self.top_p(next_token_scores)
+            next_tokens = self.sample(next_token_scores)
+            finish, next_tokens = self.isFinish(next_tokens)
+            if finish:
+                break
+            input_ids = torch.cat([input_ids, next_tokens], dim=-1)
+            if gen_length != 0 and (input_length + gen_length) < input_ids.shape[1]:
+                break
+        decoded, response, end_reason = decode_tokens(
+            input_ids[0],
+            tokenizer,
+            raw_text_len=len(raw_text),
+            context_length=len(context_tokens),
+            errors="replace",
+        )
+        history.append((query, response))
+        return input_ids[0].cpu().tolist(), history, decoded
+
+    def prepareInput(self, tokenizer, query, query_assistant, history, system):
+        return make_context(tokenizer, query, query_assistant, history=history, system=system)
+
+    def repetition_penalty(self, input_ids, next_token_scores):
+        penalty = self.qwen.config.repetition_penalty
+        score = torch.gather(next_token_scores, 1, input_ids)
+        # if score < 0 then repetition penalty has to be multiplied to reduce the token probabilities
+        score = torch.where(score < 0, score * penalty, score / penalty)
+        next_token_scores = next_token_scores.scatter_(1, input_ids, score)
+        return next_token_scores
+
+    def top_p(self, next_token_scores):
+        top_p = self.qwen.config.top_p
+        filter_value = -float("Inf")
+        min_tokens_to_keep = 1
+        sorted_logits, sorted_indices = torch.sort(next_token_scores, descending=False)
+        cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
+        # Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
+        sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
+        # Keep at least min_tokens_to_keep
+        sorted_indices_to_remove[..., -min_tokens_to_keep:] = 0
+        # scatter sorted tensors to original indexing
+        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
+        next_token_scores = next_token_scores.masked_fill(indices_to_remove, filter_value)
+        return next_token_scores
+
+    def sample(self, next_token_scores):
+        probs = nn.functional.softmax(next_token_scores, dim=-1)
+        next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
+        return next_tokens
+
+    def isFinish(self, next_tokens):
+        pad_token_id = self.qwen.config.pad_token_id
+        eos_token_id_tensor = torch.tensor([self.qwen.config.eos_token_id]).to(next_tokens.device)
+        next_tokens = next_tokens * self.unfinished_sequences + pad_token_id * (1 - self.unfinished_sequences)
+        self.unfinished_sequences = self.unfinished_sequences.mul(
+            next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
+        )
+        return self.unfinished_sequences.max() == 0, next_tokens[:, None]
+
 class QwenModule(pl.LightningModule):
     def __init__(self, conf: TrainConfig = None):
         self.config = conf
@@ -24,7 +180,7 @@ class QwenModule(pl.LightningModule):
         if pretrained_model_dir != None:
             from modelscope import snapshot_download
-            model = model.from_pretrained(snapshot_download(pretrained_model_dir))
+            model = LoadModule.from_pretrained(snapshot_download(pretrained_model_dir))
         self.llm = self.register_core_module(model)
         self.learning_rate = learning_rate
         self.use_tril_attention_mask = use_tril_attention_mask

wit/query_block_output.py (new file, 51 lines)

@@ -0,0 +1,51 @@
import torch
from model.qwen_module import QwenModule
from model.qwen_module import ModelRunner
import numpy as np
import math
import sys
sys.path.append("..")
from tools import show
import dataset.dataset as ds
if __name__ == "__main__":
# checkpoint_path = "log/bigger/version_0/checkpoints/epoch=19-step=98720.ckpt"
checkpoint_path = "log/bigger/version_1/checkpoints/epoch=14-step=74040.ckpt"
checkpoint_path = "log/bigger/version_3/checkpoints/epoch=46-step=231992.ckpt"
checkpoint_path = "log/bigger/version_8/checkpoints/epoch=49-step=246800.ckpt"
qwen = QwenModule.load_from_checkpoint(checkpoint_path=checkpoint_path)
qwen.eval()
conf = qwen.config
torch.manual_seed(conf.seed)
np.random.seed(conf.seed)
runner = ModelRunner(qwen.llm)
def DumpQK(query, key, causal_mask, index):
size = query.shape[2]
scale_factor = 1 / math.sqrt(query.size(-1))
attn_weight = query @ key.transpose(-2, -1) * scale_factor
attn_mask = torch.ones(causal_mask.shape, dtype=query.dtype, device=query.device)
attn_mask.masked_fill_(causal_mask.logical_not(), float(0))
attn_weight = attn_weight * attn_mask
attn_weight = torch.softmax(attn_weight, dim=-1)
attn_weight = attn_weight * attn_mask
qk = attn_weight[0]
prePath = "./temp/" + "q@k_seq_" + str(size) + "_layer_" + str(index) + ".png"
show.DumpTensorToImage(qk, prePath, GridValue=255)
# qk_seq.append(qk)
# qk_index = size
qwen.llm.hook_attention = DumpQK
batch = torch.tensor([[11, 0, 3, 7, 15, 8, 10, 7]], dtype=torch.int64)
sorted_logits, sorted_indices = runner.ChatTokens(batch, sample=False)
print(sorted_logits.detach().cpu().numpy())
print(sorted_indices.detach().cpu().numpy())
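
DumpQK above plugs into the hook_attention callback added to QWenLMHeadModel in this change set: during forward it is invoked once per transformer layer as hook_attention(query, key, causal_mask, index) (see the block loop in the modeling code earlier in this diff). A minimal hook that only inspects what arrives, with an illustrative name, could look like:

def inspect_qk(query, key, causal_mask, index):
    # Shapes only, no computation; called once per layer during the forward pass.
    print(f"layer {index}: query {tuple(query.shape)}, key {tuple(key.shape)}, mask {tuple(causal_mask.shape)}")

qwen.llm.hook_attention = inspect_qk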


@@ -2,7 +2,7 @@ import pytorch_lightning as pl
 import torch
 from model.qwen_module import QwenModule
-from model.modeling_wit import QwenRunner
+from model.modeling_wit import ModelRunner
 from model.tokenization_qwen import QWenTokenizer
 
 import numpy as np


@@ -18,7 +18,7 @@ if __name__ == "__main__":
     conf.pretrain_model_name = None  # "qwen/Qwen-1_8B-Chat"
     conf.learning_rate = 0.0001
     conf.use_tril_attention_mask = None
-    conf.precision = "32-true"  # "precision:bf16-mixed,16-mixed,32-true"
+    conf.precision = "16-mixed"  # "precision:bf16-mixed,16-mixed,32-true"
     conf.train_batch_size = 16
     conf.val_batch_size = 2
     conf.num_proc = 8
@@ -38,7 +38,7 @@ if __name__ == "__main__":
     config.vocab_size = 32
     config.hidden_size = 128  # 128 1024 2048 32
     config.num_hidden_layers = 3  # 6 12 24 3
-    config.num_attention_heads = 16  # 8 8 16
+    config.num_attention_heads = 8  # 8 8 16
 
     torch.manual_seed(conf.seed)
     np.random.seed(conf.seed)