Compare commits
No commits in common. "a4fafd460fed39385f4ba9ff9ca6731db8ddfb38" and "1811b9611abb70a72f54f54e08c9584e21af1d46" have entirely different histories.
a4fafd460f
...
1811b9611a
|
@ -204,20 +204,46 @@ class QwenRunner:
|
||||||
pad_token_id = qwen.config.pad_token_id
|
pad_token_id = qwen.config.pad_token_id
|
||||||
|
|
||||||
unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
|
unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
|
||||||
|
this_peer_finished = False
|
||||||
while True:
|
while True:
|
||||||
outputs = self.forwardQWen(input_ids)
|
outputs = self.forwardQWen(input_ids)
|
||||||
next_token_scores = outputs[:, -1, :]
|
next_token_scores = outputs[:, -1, :]
|
||||||
|
|
||||||
next_token_scores = self.repetition_penalty(input_ids, next_token_scores)
|
# repetition_penalty
|
||||||
next_token_scores = self.top_p(next_token_scores)
|
penalty = qwen.config.repetition_penalty
|
||||||
next_tokens = self.sample(next_token_scores)
|
score = torch.gather(next_token_scores, 1, input_ids)
|
||||||
|
# if score < 0 then repetition penalty has to be multiplied to reduce the token probabilities
|
||||||
|
score = torch.where(score < 0, score * penalty, score / penalty)
|
||||||
|
next_token_scores = next_token_scores.scatter_(1, input_ids, score)
|
||||||
|
|
||||||
|
# top_p
|
||||||
|
top_p = qwen.config.top_p
|
||||||
|
filter_value = -float("Inf")
|
||||||
|
min_tokens_to_keep = 1
|
||||||
|
sorted_logits, sorted_indices = torch.sort(next_token_scores, descending=False)
|
||||||
|
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
|
||||||
|
# Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
|
||||||
|
sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
|
||||||
|
# Keep at least min_tokens_to_keep
|
||||||
|
sorted_indices_to_remove[..., -min_tokens_to_keep:] = 0
|
||||||
|
# scatter sorted tensors to original indexing
|
||||||
|
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
|
||||||
|
next_token_scores = next_token_scores.masked_fill(indices_to_remove, filter_value)
|
||||||
|
|
||||||
|
# sample
|
||||||
|
probs = nn.functional.softmax(next_token_scores, dim=-1)
|
||||||
|
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
|
||||||
|
|
||||||
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
|
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
|
||||||
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
|
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
|
||||||
unfinished_sequences = unfinished_sequences.mul(
|
unfinished_sequences = unfinished_sequences.mul(
|
||||||
next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
|
next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
|
||||||
)
|
)
|
||||||
|
|
||||||
if unfinished_sequences.max() == 0:
|
if unfinished_sequences.max() == 0:
|
||||||
|
this_peer_finished = True
|
||||||
|
|
||||||
|
if this_peer_finished:
|
||||||
break
|
break
|
||||||
|
|
||||||
decoded, response, end_reason = decode_tokens(
|
decoded, response, end_reason = decode_tokens(
|
||||||
|
@ -228,7 +254,7 @@ class QwenRunner:
|
||||||
errors="replace",
|
errors="replace",
|
||||||
)
|
)
|
||||||
history.append((query, response))
|
history.append((query, response))
|
||||||
return input_ids[0].cpu().tolist(), history, decoded
|
return response, history, decoded
|
||||||
|
|
||||||
def _rotate_half(self, x):
|
def _rotate_half(self, x):
|
||||||
x = rearrange(x, "... (j d) -> ... j d", j=2)
|
x = rearrange(x, "... (j d) -> ... j d", j=2)
|
||||||
|
@ -353,31 +379,3 @@ class QwenRunner:
|
||||||
# loss.backward()
|
# loss.backward()
|
||||||
|
|
||||||
return lm_logits
|
return lm_logits
|
||||||
|
|
||||||
def repetition_penalty(self, input_ids, next_token_scores):
|
|
||||||
penalty = self.qwen.config.repetition_penalty
|
|
||||||
score = torch.gather(next_token_scores, 1, input_ids)
|
|
||||||
# if score < 0 then repetition penalty has to be multiplied to reduce the token probabilities
|
|
||||||
score = torch.where(score < 0, score * penalty, score / penalty)
|
|
||||||
next_token_scores = next_token_scores.scatter_(1, input_ids, score)
|
|
||||||
return next_token_scores
|
|
||||||
|
|
||||||
def top_p(self, next_token_scores):
|
|
||||||
top_p = self.qwen.config.top_p
|
|
||||||
filter_value = -float("Inf")
|
|
||||||
min_tokens_to_keep = 1
|
|
||||||
sorted_logits, sorted_indices = torch.sort(next_token_scores, descending=False)
|
|
||||||
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
|
|
||||||
# Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
|
|
||||||
sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
|
|
||||||
# Keep at least min_tokens_to_keep
|
|
||||||
sorted_indices_to_remove[..., -min_tokens_to_keep:] = 0
|
|
||||||
# scatter sorted tensors to original indexing
|
|
||||||
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
|
|
||||||
next_token_scores = next_token_scores.masked_fill(indices_to_remove, filter_value)
|
|
||||||
return next_token_scores
|
|
||||||
|
|
||||||
def sample(self, next_token_scores):
|
|
||||||
probs = nn.functional.softmax(next_token_scores, dim=-1)
|
|
||||||
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
|
|
||||||
return next_tokens
|
|
||||||
|
|
|
@ -70,7 +70,12 @@ def Dump_lm_head_weight(model):
|
||||||
# Dump_lm_head_weight(model)
|
# Dump_lm_head_weight(model)
|
||||||
|
|
||||||
|
|
||||||
def DumpQK(query, key, causal_mask, index):
|
class ResearchRunner(QwenRunner):
|
||||||
|
def attention(self, attention, query, key, value, causal_mask):
|
||||||
|
query = query.permute(0, 2, 1, 3)
|
||||||
|
key = key.permute(0, 2, 1, 3)
|
||||||
|
value = value.permute(0, 2, 1, 3)
|
||||||
|
|
||||||
scale_factor = 1 / math.sqrt(query.size(-1))
|
scale_factor = 1 / math.sqrt(query.size(-1))
|
||||||
attn_weight = query @ key.transpose(-2, -1) * scale_factor
|
attn_weight = query @ key.transpose(-2, -1) * scale_factor
|
||||||
attn_weight = torch.softmax(attn_weight, dim=-1)
|
attn_weight = torch.softmax(attn_weight, dim=-1)
|
||||||
|
@ -79,39 +84,20 @@ def DumpQK(query, key, causal_mask, index):
|
||||||
attn_mask.masked_fill_(causal_mask.logical_not(), float(0))
|
attn_mask.masked_fill_(causal_mask.logical_not(), float(0))
|
||||||
qk = attn_weight * attn_mask
|
qk = attn_weight * attn_mask
|
||||||
qk = qk[0]
|
qk = qk[0]
|
||||||
prePath = "./temp/" + "q@k_seq_" + str(size) + "_layer_" + str(index) + ".png"
|
prePath = "./temp/" + "q@k_seq_" + str(size) + "_layer_" + str(attention.index) + ".png"
|
||||||
show.DumpTensorToImage(qk, prePath, GridValue=255)
|
show.DumpTensorToImage(qk, prePath, GridValue=255)
|
||||||
|
|
||||||
|
|
||||||
class ResearchRunner(QwenRunner):
|
|
||||||
def __init__(self, model):
|
|
||||||
super().__init__(model)
|
|
||||||
|
|
||||||
def attention(self, attention, query, key, value, causal_mask):
|
|
||||||
query = query.permute(0, 2, 1, 3)
|
|
||||||
key = key.permute(0, 2, 1, 3)
|
|
||||||
value = value.permute(0, 2, 1, 3)
|
|
||||||
|
|
||||||
DumpQK(query, key, causal_mask, attention.index)
|
|
||||||
|
|
||||||
attn_output = F.scaled_dot_product_attention(query, key, value, attn_mask=causal_mask).transpose(1, 2)
|
attn_output = F.scaled_dot_product_attention(query, key, value, attn_mask=causal_mask).transpose(1, 2)
|
||||||
context_layer = attention._merge_heads(attn_output, attention.num_heads, attention.head_dim)
|
context_layer = attention._merge_heads(attn_output, attention.num_heads, attention.head_dim)
|
||||||
attn_output = attention.c_proj(context_layer)
|
attn_output = attention.c_proj(context_layer)
|
||||||
return attn_output
|
return attn_output
|
||||||
|
|
||||||
|
|
||||||
runner = ResearchRunner(model)
|
runner = ResearchRunner(model)
|
||||||
|
|
||||||
# 第一轮对话
|
# 第一轮对话
|
||||||
output_ids, history, decoded = runner.Chat(tokenizer, "东南亚国家日本的首都是什么市", "日本的首都是")
|
response, history, decode_tokens = runner.Chat(tokenizer, "东南亚国家日本的首都是什么市", "日本的首都是")
|
||||||
print(decoded)
|
print(decode_tokens)
|
||||||
|
|
||||||
tokens = []
|
|
||||||
for i, token in enumerate(output_ids):
|
|
||||||
de = tokenizer.decode([token])
|
|
||||||
de = str(i).zfill(3) + " : " + repr(de)
|
|
||||||
tokens.append(de)
|
|
||||||
print(de)
|
|
||||||
|
|
||||||
# <|im_start|>system
|
# <|im_start|>system
|
||||||
# You are a helpful assistant.<|im_end|>
|
# You are a helpful assistant.<|im_end|>
|
||||||
# <|im_start|>user
|
# <|im_start|>user
|
||||||
|
@ -120,7 +106,5 @@ for i, token in enumerate(output_ids):
|
||||||
# 日本的首都东京。<|im_end|>
|
# 日本的首都东京。<|im_end|>
|
||||||
# <|endoftext|>
|
# <|endoftext|>
|
||||||
|
|
||||||
show.DumpListToFile(tokens, "./temp/token_decode_list.txt")
|
if decode_tokens.split("\n")[-2] != """日本的首都东京。<|im_end|>""":
|
||||||
|
|
||||||
if decoded.split("\n")[-2] != """日本的首都东京。<|im_end|>""":
|
|
||||||
raise ()
|
raise ()
|
||||||
|
|
Loading…
Reference in New Issue