Update research_attention dump without sum.
This commit is contained in:
parent
3f296ccdb2
commit
185278f3a9
|
@ -200,10 +200,7 @@ class QwenRunner:
|
|||
history = copy.deepcopy(history)
|
||||
raw_text, context_tokens = self.prepareInput(tokenizer, query, query_assistant, history, system)
|
||||
input_ids = torch.tensor([context_tokens]).to(next(qwen.parameters()).device)
|
||||
eos_token_id_tensor = torch.tensor([qwen.config.eos_token_id]).to(input_ids.device)
|
||||
pad_token_id = qwen.config.pad_token_id
|
||||
|
||||
unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
|
||||
self.unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
|
||||
while True:
|
||||
outputs = self.forwardQWen(input_ids)
|
||||
next_token_scores = outputs[:, -1, :]
|
||||
|
@ -211,14 +208,10 @@ class QwenRunner:
|
|||
next_token_scores = self.repetition_penalty(input_ids, next_token_scores)
|
||||
next_token_scores = self.top_p(next_token_scores)
|
||||
next_tokens = self.sample(next_token_scores)
|
||||
|
||||
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
|
||||
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
|
||||
unfinished_sequences = unfinished_sequences.mul(
|
||||
next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
|
||||
)
|
||||
if unfinished_sequences.max() == 0:
|
||||
finish, next_tokens = self.isFinish(next_tokens)
|
||||
if finish:
|
||||
break
|
||||
input_ids = torch.cat([input_ids, next_tokens], dim=-1)
|
||||
|
||||
decoded, response, end_reason = decode_tokens(
|
||||
input_ids[0],
|
||||
|
@ -384,3 +377,13 @@ class QwenRunner:
|
|||
probs = nn.functional.softmax(next_token_scores, dim=-1)
|
||||
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
|
||||
return next_tokens
|
||||
|
||||
def isFinish(self, next_tokens):
|
||||
pad_token_id = self.qwen.config.pad_token_id
|
||||
eos_token_id_tensor = torch.tensor([self.qwen.config.eos_token_id]).to(next_tokens.device)
|
||||
|
||||
next_tokens = next_tokens * self.unfinished_sequences + pad_token_id * (1 - self.unfinished_sequences)
|
||||
self.unfinished_sequences = self.unfinished_sequences.mul(
|
||||
next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
|
||||
)
|
||||
return self.unfinished_sequences.max() == 0, next_tokens[:, None]
|
||||
|
|
|
@ -40,6 +40,8 @@ print(model)
|
|||
tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
|
||||
|
||||
model = model.from_pretrained(model_dir)
|
||||
if torch.cuda.device_count() > 0:
|
||||
model = model.cuda()
|
||||
model = model.eval()
|
||||
|
||||
|
||||
|
@ -70,11 +72,14 @@ def Dump_lm_head_weight(model):
|
|||
|
||||
# Dump_lm_head_weight(model)
|
||||
|
||||
qk_sum = []
|
||||
qk_index = []
|
||||
qk_seq = []
|
||||
qk_index = None
|
||||
|
||||
|
||||
def DumpQK(query, key, causal_mask, index):
|
||||
global qk_seq
|
||||
global qk_index
|
||||
size = query.shape[2]
|
||||
scale_factor = 1 / math.sqrt(query.size(-1))
|
||||
attn_weight = query @ key.transpose(-2, -1) * scale_factor
|
||||
attn_mask = torch.ones(causal_mask.shape, dtype=query.dtype, device=query.device)
|
||||
|
@ -82,12 +87,11 @@ def DumpQK(query, key, causal_mask, index):
|
|||
attn_weight = attn_weight * attn_mask
|
||||
attn_weight = torch.softmax(attn_weight, dim=-1)
|
||||
attn_weight = attn_weight * attn_mask
|
||||
size = query.shape[2]
|
||||
qk = attn_weight[0]
|
||||
# prePath = "./temp/" + "q@k_seq_" + str(size) + "_layer_" + str(index) + ".png"
|
||||
# show.DumpTensorToImage(qk, prePath, GridValue=255)
|
||||
qk_sum.append(qk.sum(0))
|
||||
qk_index.append(size)
|
||||
qk_seq.append(qk)
|
||||
qk_index = size
|
||||
|
||||
|
||||
class ResearchRunner(QwenRunner):
|
||||
|
@ -106,14 +110,6 @@ class ResearchRunner(QwenRunner):
|
|||
attn_output = attention.c_proj(context_layer)
|
||||
return attn_output
|
||||
|
||||
def sample(self, next_token_scores):
|
||||
qk_sum_cat = torch.stack(qk_sum, 0)
|
||||
qk_sum.clear()
|
||||
prePath = "./temp/" + "q@k_sum_seq_" + str(qk_index[-1]) + ".png"
|
||||
show.DumpTensorToImage(qk_sum_cat, prePath, GridValue=255)
|
||||
|
||||
return super().sample(next_token_scores)
|
||||
|
||||
def prepareInput(self, tokenizer, query, query_assistant, history, system):
|
||||
start_to = [151644]
|
||||
n_to = [198]
|
||||
|
@ -128,10 +124,21 @@ class ResearchRunner(QwenRunner):
|
|||
|
||||
tokens = system_token + user_token + aassistant_token
|
||||
tokens = user_token + aassistant_token
|
||||
tokens = start_to + tokenizer.encode("user\n你好\nassistant\n", allowed_special=set())
|
||||
tokens = start_to + tokenizer.encode("user\nHi你好\nassistant\n", allowed_special=set())
|
||||
|
||||
return "", tokens
|
||||
|
||||
def isFinish(self, next_tokens):
|
||||
global qk_seq
|
||||
finish, next = super().isFinish(next_tokens)
|
||||
if finish:
|
||||
for i, s in enumerate(qk_seq):
|
||||
prePath = "./temp/" + "q@k_layer_" + str(i) + ".png"
|
||||
show.DumpTensorToImage(s, prePath, GridValue=255)
|
||||
else:
|
||||
qk_seq = []
|
||||
return finish, next
|
||||
|
||||
|
||||
runner = ResearchRunner(model)
|
||||
|
||||
|
|
Loading…
Reference in New Issue