Refine research_attention and forward model.

This commit is contained in:
Colin 2024-01-23 13:13:21 +08:00
parent 1811b9611a
commit 11af10e710
2 changed files with 60 additions and 34 deletions

View File

@ -209,30 +209,9 @@ class QwenRunner:
outputs = self.forwardQWen(input_ids) outputs = self.forwardQWen(input_ids)
next_token_scores = outputs[:, -1, :] next_token_scores = outputs[:, -1, :]
# repetition_penalty next_token_scores = self.repetition_penalty(input_ids, next_token_scores)
penalty = qwen.config.repetition_penalty next_token_scores = self.top_p(next_token_scores)
score = torch.gather(next_token_scores, 1, input_ids) next_tokens = self.sample(next_token_scores)
# if score < 0 then repetition penalty has to be multiplied to reduce the token probabilities
score = torch.where(score < 0, score * penalty, score / penalty)
next_token_scores = next_token_scores.scatter_(1, input_ids, score)
# top_p
top_p = qwen.config.top_p
filter_value = -float("Inf")
min_tokens_to_keep = 1
sorted_logits, sorted_indices = torch.sort(next_token_scores, descending=False)
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
# Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
# Keep at least min_tokens_to_keep
sorted_indices_to_remove[..., -min_tokens_to_keep:] = 0
# scatter sorted tensors to original indexing
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
next_token_scores = next_token_scores.masked_fill(indices_to_remove, filter_value)
# sample
probs = nn.functional.softmax(next_token_scores, dim=-1)
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
@ -379,3 +358,31 @@ class QwenRunner:
# loss.backward() # loss.backward()
return lm_logits return lm_logits
def repetition_penalty(self, input_ids, next_token_scores):
penalty = self.qwen.config.repetition_penalty
score = torch.gather(next_token_scores, 1, input_ids)
# if score < 0 then repetition penalty has to be multiplied to reduce the token probabilities
score = torch.where(score < 0, score * penalty, score / penalty)
next_token_scores = next_token_scores.scatter_(1, input_ids, score)
return next_token_scores
def top_p(self, next_token_scores):
top_p = self.qwen.config.top_p
filter_value = -float("Inf")
min_tokens_to_keep = 1
sorted_logits, sorted_indices = torch.sort(next_token_scores, descending=False)
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
# Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
# Keep at least min_tokens_to_keep
sorted_indices_to_remove[..., -min_tokens_to_keep:] = 0
# scatter sorted tensors to original indexing
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
next_token_scores = next_token_scores.masked_fill(indices_to_remove, filter_value)
return next_token_scores
def sample(self, next_token_scores):
probs = nn.functional.softmax(next_token_scores, dim=-1)
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
return next_tokens

View File

@ -70,28 +70,44 @@ def Dump_lm_head_weight(model):
# Dump_lm_head_weight(model) # Dump_lm_head_weight(model)
def DumpQK(query, key, causal_mask, index):
scale_factor = 1 / math.sqrt(query.size(-1))
attn_weight = query @ key.transpose(-2, -1) * scale_factor
attn_weight = torch.softmax(attn_weight, dim=-1)
size = query.shape[2]
attn_mask = torch.ones(causal_mask.shape, dtype=query.dtype, device=query.device)
attn_mask.masked_fill_(causal_mask.logical_not(), float(0))
qk = attn_weight * attn_mask
qk = qk[0]
prePath = "./temp/" + "q@k_seq_" + str(size) + "_layer_" + str(index) + ".png"
show.DumpTensorToImage(qk, prePath, GridValue=255)
class ResearchRunner(QwenRunner): class ResearchRunner(QwenRunner):
def __init__(self, model):
super().__init__(model)
self.tokenDecode = []
def attention(self, attention, query, key, value, causal_mask): def attention(self, attention, query, key, value, causal_mask):
query = query.permute(0, 2, 1, 3) query = query.permute(0, 2, 1, 3)
key = key.permute(0, 2, 1, 3) key = key.permute(0, 2, 1, 3)
value = value.permute(0, 2, 1, 3) value = value.permute(0, 2, 1, 3)
scale_factor = 1 / math.sqrt(query.size(-1)) DumpQK(query, key, causal_mask, attention.index)
attn_weight = query @ key.transpose(-2, -1) * scale_factor
attn_weight = torch.softmax(attn_weight, dim=-1)
size = query.shape[2]
attn_mask = torch.ones(causal_mask.shape, dtype=query.dtype, device=query.device)
attn_mask.masked_fill_(causal_mask.logical_not(), float(0))
qk = attn_weight * attn_mask
qk = qk[0]
prePath = "./temp/" + "q@k_seq_" + str(size) + "_layer_" + str(attention.index) + ".png"
show.DumpTensorToImage(qk, prePath, GridValue=255)
attn_output = F.scaled_dot_product_attention(query, key, value, attn_mask=causal_mask).transpose(1, 2) attn_output = F.scaled_dot_product_attention(query, key, value, attn_mask=causal_mask).transpose(1, 2)
context_layer = attention._merge_heads(attn_output, attention.num_heads, attention.head_dim) context_layer = attention._merge_heads(attn_output, attention.num_heads, attention.head_dim)
attn_output = attention.c_proj(context_layer) attn_output = attention.c_proj(context_layer)
return attn_output return attn_output
def sample(self, next_token_scores):
next_tokens = super().sample(next_token_scores)
decoded, response, end_reason = decode_tokens(
next_tokens,
tokenizer,
)
self.tokenDecode.append(decoded)
return next_tokens
runner = ResearchRunner(model) runner = ResearchRunner(model)
@ -106,5 +122,8 @@ print(decode_tokens)
# 日本的首都东京。<|im_end|> # 日本的首都东京。<|im_end|>
# <|endoftext|> # <|endoftext|>
show.DumpListToFile(runner.tokenDecode, "./temp/token_decode_list.txt")
if decode_tokens.split("\n")[-2] != """日本的首都东京。<|im_end|>""": if decode_tokens.split("\n")[-2] != """日本的首都东京。<|im_end|>""":
raise () raise ()