137 lines
3.7 KiB
Python
137 lines
3.7 KiB
Python
|
import torch
|
||
|
import sys
|
||
|
|
||
|
# from modelscope import snapshot_download
|
||
|
from transformers import AutoTokenizer
|
||
|
from transformers import AutoConfig
|
||
|
|
||
|
from modeling_qwen import QWenLMHeadModel
|
||
|
from modeling_qwen import QwenRunner
|
||
|
|
||
|
|
||
|
from qwen_generation_utils import (
|
||
|
make_context,
|
||
|
decode_tokens,
|
||
|
)
|
||
|
|
||
|
import torch.nn.functional as F
|
||
|
|
||
|
sys.path.append("..")
|
||
|
from tools import show
|
||
|
|
||
|
import os

seed = 4321
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

# Local copy of Qwen-1_8B-Chat; override with the QWEN_MODEL_DIR environment
# variable. It can alternatively be fetched with
#   model_dir = snapshot_download("qwen/Qwen-1_8B-Chat")
model_dir = os.environ.get(
    "QWEN_MODEL_DIR", "/home/colin/.cache/modelscope/hub/qwen/Qwen-1_8B-Chat"
)

# Config from the working directory, kept as the module-level `config` and
# `kwargs` names.
# NOTE(review): the weighted model below is not built from this config;
# from_pretrained presumably reads model_dir's own config.json -- confirm the
# two agree.
config, kwargs = AutoConfig.from_pretrained(
    "./",
    return_unused_kwargs=True,
    trust_remote_code=True,
    code_revision=None,
    _commit_hash=None,
)

# Call the from_pretrained classmethod directly. Constructing
# QWenLMHeadModel(config) first would allocate a full randomly initialised
# model only to discard it, doubling peak memory during loading.
model = QWenLMHeadModel.from_pretrained(model_dir)
print(model)

tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)

if torch.cuda.device_count() > 0:
    model = model.cuda()
model = model.eval()
|
||
|
|
||
|
|
||
|
class ResearchRunner(QwenRunner):
    """QwenRunner variant that captures one head's query/key vectors.

    Only the first transformer block is executed. Instead of computing
    attention, ``attention`` appends the selected head's query and key
    slices to the module-level tensors ``q`` and ``k``. The head is chosen by
    the module-level ``head_group_index``.
    """

    def __init__(self, model):
        super().__init__(model)

    @torch.no_grad()
    def forwardQWen(
        self,
        input_ids=None,
        labels=None,
    ):
        """Embed ``input_ids`` and run only the first transformer block.

        Runs under ``torch.no_grad()``. The captured q/k tensors outlive the
        call, so tracking gradients would keep every chunk's activations
        alive. ``labels`` is accepted for signature compatibility and
        ignored. Returns None.
        """
        transfm = self.qwen.transformer
        input_shape = input_ids.size()
        # Flatten any leading dimensions into a single batch axis.
        input_ids = input_ids.view(-1, input_shape[-1])
        hidden_states = transfm.wte(input_ids)
        kv_seq_len = hidden_states.size()[1]

        transfm.update_rotary_pos_emb_cache(kv_seq_len, ntk_alpha=1.0)
        cos, sin = transfm._rotary_pos_emb_cache
        rotary_pos_emb_list = [[cos[:, :kv_seq_len], sin[:, :kv_seq_len]]]

        hidden_states = transfm.drop(hidden_states)

        # Only the first layer is analysed: its attention call captures q/k.
        for block in transfm.h:
            self.forwardQWenBlock(
                block, hidden_states, rotary_pos_emb_list=rotary_pos_emb_list
            )
            break

    def forwardQWenBlock(
        self,
        block,
        hidden_states,
        rotary_pos_emb_list=None,
    ):
        """Apply ``block.ln_1`` and hand the result to ``forwardAttention``.

        The block's MLP and residual path are not evaluated; the return
        value of ``forwardAttention`` is discarded.
        """
        layernorm_output = block.ln_1(hidden_states)
        self.forwardAttention(block.attn, layernorm_output, rotary_pos_emb_list)

    def attention(self, attention, query, key, value, causal_mask):
        """Append head ``head_group_index``'s query/key to global ``q``/``k``.

        ``attention``, ``value`` and ``causal_mask`` are unused, and nothing
        is returned.
        NOTE(review): assumes query/key arrive as (batch, seq, heads,
        head_dim), so that after ``permute(0, 2, 1, 3)`` dim 1 indexes heads.
        Confirm this against QwenRunner.forwardAttention.
        """
        global q
        global k
        # Detach before storing: the globals outlive this forward pass and
        # must not hold a reference to its autograd graph.
        query = query.permute(0, 2, 1, 3)[:, head_group_index, :, :].detach()
        key = key.permute(0, 2, 1, 3)[:, head_group_index, :, :].detach()
        q = torch.cat([q, query], 1)
        k = torch.cat([k, key], 1)
|
||
|
|
||
|
|
||
|
# Head of the first layer whose queries/keys are captured (read by
# ResearchRunner.attention).
head_group_index = 0
# Number of token ids probed: every id in range(total_token) is decoded and
# fed through the model.
total_token = 151851
# Number of highest-scoring keys reported per query token.
topk = 10
# Number of token ids fed through the model per forward pass.
patch_size = 1000

# repr() of each single-token decode, indexed by token id.
tokens_str = []
for token in range(total_token):
    decoded, response, end_reason = decode_tokens(
        [token],
        tokenizer,
        raw_text_len=0,
        context_length=0,
        errors="replace",
    )
    tokens_str.append(repr(decoded))

device = next(model.parameters()).device

# Exclusive end of each chunk; chunks tile [0, total_token) in patch_size steps.
patch_end = list(range(patch_size, total_token, patch_size)) + [total_token]

# Accumulators appended to by ResearchRunner.attention along dim 1.
# The dtype is float64 (same as dtype=float).
q = torch.zeros((1, 0, 128), dtype=torch.float64, device=device)
k = torch.zeros((1, 0, 128), dtype=torch.float64, device=device)

patch_start = 0
# Inference only: without no_grad the captured q/k would keep every chunk's
# autograd graph (and its activations) alive until the end of the script.
with torch.no_grad():
    for end in patch_end:
        tokens = list(range(patch_start, end))
        patch_start = end
        input_ids = torch.tensor([tokens], device=device)
        # NOTE(review): a fresh runner per chunk; hoist it out of the loop if
        # QwenRunner keeps no per-call state.
        runner = ResearchRunner(model)
        runner.forwardQWen(input_ids)
|
||
|
|
||
|
import os

# Drop the batch axis: q becomes (N, 128) and k is transposed to (128, N), so
# q[i] @ k scores query i against every captured key.
q = q[0, :, :]
k = k[0, :, :].permute(1, 0)

if q.size(0) != total_token or k.size(1) != total_token:
    raise RuntimeError(
        f"captured {q.size(0)} queries / {k.size(1)} keys, "
        f"expected {total_token} of each"
    )

# Query rows scored per matmul; bounds the (rows x total_token) score matrix.
row_batch = 256

# One line per token: "<id zero-padded to 7> <token> : <top-k key tokens> ".
token_topk = []
with torch.no_grad():
    for start in range(0, total_token, row_batch):
        stop = min(start + row_batch, total_token)
        scores = q[start:stop, :] @ k
        values, indices = torch.topk(scores, topk, dim=-1)
        # A single .tolist() per batch avoids one device sync per index.
        for i, row in enumerate(indices.tolist(), start=start):
            neighbours = "".join(tokens_str[index] + " " for index in row)
            token_topk.append(f"{str(i).zfill(7)} {tokens_str[i]} : {neighbours}")

os.makedirs("./temp", exist_ok=True)
show.DumpListToFile(
    token_topk,
    "./temp/qwen_token_qk_topk_head_group_" + str(head_group_index) + ".txt",
)

print("decoded")
|