Compare commits
2 Commits
19491d1f4a
...
3f296ccdb2
Author | SHA1 | Date |
---|---|---|
Colin | 3f296ccdb2 | |
Colin | bba27e3444 |
|
@ -198,7 +198,7 @@ class QwenRunner:
|
||||||
):
|
):
|
||||||
qwen = self.qwen
|
qwen = self.qwen
|
||||||
history = copy.deepcopy(history)
|
history = copy.deepcopy(history)
|
||||||
raw_text, context_tokens = make_context(tokenizer, query, query_assistant, history=history, system=system)
|
raw_text, context_tokens = self.prepareInput(tokenizer, query, query_assistant, history, system)
|
||||||
input_ids = torch.tensor([context_tokens]).to(next(qwen.parameters()).device)
|
input_ids = torch.tensor([context_tokens]).to(next(qwen.parameters()).device)
|
||||||
eos_token_id_tensor = torch.tensor([qwen.config.eos_token_id]).to(input_ids.device)
|
eos_token_id_tensor = torch.tensor([qwen.config.eos_token_id]).to(input_ids.device)
|
||||||
pad_token_id = qwen.config.pad_token_id
|
pad_token_id = qwen.config.pad_token_id
|
||||||
|
@ -354,6 +354,9 @@ class QwenRunner:
|
||||||
|
|
||||||
return lm_logits
|
return lm_logits
|
||||||
|
|
||||||
|
def prepareInput(self, tokenizer, query, query_assistant, history, system):
|
||||||
|
return make_context(tokenizer, query, query_assistant, history=history, system=system)
|
||||||
|
|
||||||
def repetition_penalty(self, input_ids, next_token_scores):
|
def repetition_penalty(self, input_ids, next_token_scores):
|
||||||
penalty = self.qwen.config.repetition_penalty
|
penalty = self.qwen.config.repetition_penalty
|
||||||
score = torch.gather(next_token_scores, 1, input_ids)
|
score = torch.gather(next_token_scores, 1, input_ids)
|
||||||
|
|
|
@ -7,6 +7,7 @@ from transformers import AutoConfig
|
||||||
|
|
||||||
from modeling_qwen import QWenLMHeadModel
|
from modeling_qwen import QWenLMHeadModel
|
||||||
from modeling_qwen import QwenRunner
|
from modeling_qwen import QwenRunner
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
@ -69,18 +70,24 @@ def Dump_lm_head_weight(model):
|
||||||
|
|
||||||
# Dump_lm_head_weight(model)
|
# Dump_lm_head_weight(model)
|
||||||
|
|
||||||
|
qk_sum = []
|
||||||
|
qk_index = []
|
||||||
|
|
||||||
|
|
||||||
def DumpQK(query, key, causal_mask, index):
|
def DumpQK(query, key, causal_mask, index):
|
||||||
scale_factor = 1 / math.sqrt(query.size(-1))
|
scale_factor = 1 / math.sqrt(query.size(-1))
|
||||||
attn_weight = query @ key.transpose(-2, -1) * scale_factor
|
attn_weight = query @ key.transpose(-2, -1) * scale_factor
|
||||||
attn_weight = torch.softmax(attn_weight, dim=-1)
|
|
||||||
size = query.shape[2]
|
|
||||||
attn_mask = torch.ones(causal_mask.shape, dtype=query.dtype, device=query.device)
|
attn_mask = torch.ones(causal_mask.shape, dtype=query.dtype, device=query.device)
|
||||||
attn_mask.masked_fill_(causal_mask.logical_not(), float(0))
|
attn_mask.masked_fill_(causal_mask.logical_not(), float(0))
|
||||||
qk = attn_weight * attn_mask
|
attn_weight = attn_weight * attn_mask
|
||||||
qk = qk[0]
|
attn_weight = torch.softmax(attn_weight, dim=-1)
|
||||||
prePath = "./temp/" + "q@k_seq_" + str(size) + "_layer_" + str(index) + ".png"
|
attn_weight = attn_weight * attn_mask
|
||||||
show.DumpTensorToImage(qk, prePath, GridValue=255)
|
size = query.shape[2]
|
||||||
|
qk = attn_weight[0]
|
||||||
|
# prePath = "./temp/" + "q@k_seq_" + str(size) + "_layer_" + str(index) + ".png"
|
||||||
|
# show.DumpTensorToImage(qk, prePath, GridValue=255)
|
||||||
|
qk_sum.append(qk.sum(0))
|
||||||
|
qk_index.append(size)
|
||||||
|
|
||||||
|
|
||||||
class ResearchRunner(QwenRunner):
|
class ResearchRunner(QwenRunner):
|
||||||
|
@ -99,18 +106,48 @@ class ResearchRunner(QwenRunner):
|
||||||
attn_output = attention.c_proj(context_layer)
|
attn_output = attention.c_proj(context_layer)
|
||||||
return attn_output
|
return attn_output
|
||||||
|
|
||||||
|
def sample(self, next_token_scores):
|
||||||
|
qk_sum_cat = torch.stack(qk_sum, 0)
|
||||||
|
qk_sum.clear()
|
||||||
|
prePath = "./temp/" + "q@k_sum_seq_" + str(qk_index[-1]) + ".png"
|
||||||
|
show.DumpTensorToImage(qk_sum_cat, prePath, GridValue=255)
|
||||||
|
|
||||||
|
return super().sample(next_token_scores)
|
||||||
|
|
||||||
|
def prepareInput(self, tokenizer, query, query_assistant, history, system):
|
||||||
|
start_to = [151644]
|
||||||
|
n_to = [198]
|
||||||
|
end_to = [151645]
|
||||||
|
system_str = "system\nYou are a helpful assistant."
|
||||||
|
user_str = "user\n" + query
|
||||||
|
aassistant_str = "assistant\n" + query_assistant
|
||||||
|
|
||||||
|
system_token = start_to + tokenizer.encode(system_str, allowed_special=set()) + end_to + n_to
|
||||||
|
user_token = start_to + tokenizer.encode(user_str, allowed_special=set()) + end_to + n_to
|
||||||
|
aassistant_token = start_to + tokenizer.encode(aassistant_str, allowed_special=set())
|
||||||
|
|
||||||
|
tokens = system_token + user_token + aassistant_token
|
||||||
|
tokens = user_token + aassistant_token
|
||||||
|
tokens = start_to + tokenizer.encode("user\n你好\nassistant\n", allowed_special=set())
|
||||||
|
|
||||||
|
return "", tokens
|
||||||
|
|
||||||
|
|
||||||
runner = ResearchRunner(model)
|
runner = ResearchRunner(model)
|
||||||
|
|
||||||
# 第一轮对话
|
# 第一轮对话
|
||||||
output_ids, history, decoded = runner.Chat(tokenizer, "东南亚国家日本的首都是什么市", "日本的首都是")
|
# output_ids, history, decoded = runner.Chat(tokenizer, "东南亚国家日本的首都是什么市", "日本的首都是")
|
||||||
|
# print(decoded)
|
||||||
|
|
||||||
|
output_ids, history, decoded = runner.Chat(tokenizer, "你好!!", "")
|
||||||
print(decoded)
|
print(decoded)
|
||||||
|
|
||||||
|
|
||||||
tokens = []
|
tokens = []
|
||||||
for i, token in enumerate(output_ids):
|
for i, token in enumerate(output_ids):
|
||||||
de = tokenizer.decode([token])
|
de = tokenizer.decode([token])
|
||||||
de = str(i).zfill(3) + " : " + repr(de)
|
de = str(i + 1).zfill(3) + " : " + repr(de)
|
||||||
tokens.append(de)
|
tokens.append(de)
|
||||||
print(de)
|
|
||||||
|
|
||||||
# <|im_start|>system
|
# <|im_start|>system
|
||||||
# You are a helpful assistant.<|im_end|>
|
# You are a helpful assistant.<|im_end|>
|
||||||
|
@ -122,5 +159,9 @@ for i, token in enumerate(output_ids):
|
||||||
|
|
||||||
show.DumpListToFile(tokens, "./temp/token_decode_list.txt")
|
show.DumpListToFile(tokens, "./temp/token_decode_list.txt")
|
||||||
|
|
||||||
if decoded.split("\n")[-2] != """日本的首都东京。<|im_end|>""":
|
# if decoded.split("\n")[-2] != """日本的首都东京。<|im_end|>""":
|
||||||
raise ()
|
# raise ()
|
||||||
|
|
||||||
|
|
||||||
|
# normal (x - mean) / (std + eps) => sum(y)==0
|
||||||
|
# softmax exp(x) / sum(exp(x)) => 0 < y < 1 sum(y)==1
|
||||||
|
|
Loading…
Reference in New Issue