Update research.
This commit is contained in:
parent
bba27e3444
commit
3f296ccdb2
|
@ -7,6 +7,7 @@ from transformers import AutoConfig
|
||||||
|
|
||||||
from modeling_qwen import QWenLMHeadModel
|
from modeling_qwen import QWenLMHeadModel
|
||||||
from modeling_qwen import QwenRunner
|
from modeling_qwen import QwenRunner
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
@ -69,6 +70,9 @@ def Dump_lm_head_weight(model):
|
||||||
|
|
||||||
# Dump_lm_head_weight(model)
|
# Dump_lm_head_weight(model)
|
||||||
|
|
||||||
|
qk_sum = []
|
||||||
|
qk_index = []
|
||||||
|
|
||||||
|
|
||||||
def DumpQK(query, key, causal_mask, index):
|
def DumpQK(query, key, causal_mask, index):
|
||||||
scale_factor = 1 / math.sqrt(query.size(-1))
|
scale_factor = 1 / math.sqrt(query.size(-1))
|
||||||
|
@ -77,12 +81,13 @@ def DumpQK(query, key, causal_mask, index):
|
||||||
attn_mask.masked_fill_(causal_mask.logical_not(), float(0))
|
attn_mask.masked_fill_(causal_mask.logical_not(), float(0))
|
||||||
attn_weight = attn_weight * attn_mask
|
attn_weight = attn_weight * attn_mask
|
||||||
attn_weight = torch.softmax(attn_weight, dim=-1)
|
attn_weight = torch.softmax(attn_weight, dim=-1)
|
||||||
|
attn_weight = attn_weight * attn_mask
|
||||||
size = query.shape[2]
|
size = query.shape[2]
|
||||||
qk = attn_weight[0]
|
qk = attn_weight[0]
|
||||||
# prePath = "./temp/" + "q@k_seq_" + str(size) + "_layer_" + str(index) + ".png"
|
# prePath = "./temp/" + "q@k_seq_" + str(size) + "_layer_" + str(index) + ".png"
|
||||||
# show.DumpTensorToImage(qk, prePath, GridValue=255)
|
# show.DumpTensorToImage(qk, prePath, GridValue=255)
|
||||||
prePath = "./temp/" + "q@k_sum_seq_" + str(size) + "_layer_" + str(index) + ".png"
|
qk_sum.append(qk.sum(0))
|
||||||
show.DumpTensorToImage(qk.sum(0), prePath, GridValue=255)
|
qk_index.append(size)
|
||||||
|
|
||||||
|
|
||||||
class ResearchRunner(QwenRunner):
|
class ResearchRunner(QwenRunner):
|
||||||
|
@ -101,6 +106,14 @@ class ResearchRunner(QwenRunner):
|
||||||
attn_output = attention.c_proj(context_layer)
|
attn_output = attention.c_proj(context_layer)
|
||||||
return attn_output
|
return attn_output
|
||||||
|
|
||||||
|
def sample(self, next_token_scores):
|
||||||
|
qk_sum_cat = torch.stack(qk_sum, 0)
|
||||||
|
qk_sum.clear()
|
||||||
|
prePath = "./temp/" + "q@k_sum_seq_" + str(qk_index[-1]) + ".png"
|
||||||
|
show.DumpTensorToImage(qk_sum_cat, prePath, GridValue=255)
|
||||||
|
|
||||||
|
return super().sample(next_token_scores)
|
||||||
|
|
||||||
def prepareInput(self, tokenizer, query, query_assistant, history, system):
|
def prepareInput(self, tokenizer, query, query_assistant, history, system):
|
||||||
start_to = [151644]
|
start_to = [151644]
|
||||||
n_to = [198]
|
n_to = [198]
|
||||||
|
@ -129,6 +142,7 @@ runner = ResearchRunner(model)
|
||||||
output_ids, history, decoded = runner.Chat(tokenizer, "你好!!", "")
|
output_ids, history, decoded = runner.Chat(tokenizer, "你好!!", "")
|
||||||
print(decoded)
|
print(decoded)
|
||||||
|
|
||||||
|
|
||||||
tokens = []
|
tokens = []
|
||||||
for i, token in enumerate(output_ids):
|
for i, token in enumerate(output_ids):
|
||||||
de = tokenizer.decode([token])
|
de = tokenizer.decode([token])
|
||||||
|
|
Loading…
Reference in New Issue