From 11af10e7106db9970db246527d7261e39a9163b2 Mon Sep 17 00:00:00 2001 From: Colin Date: Tue, 23 Jan 2024 13:13:21 +0800 Subject: [PATCH] Refine research_attention and forward model. --- qwen/modeling_qwen.py | 55 +++++++++++++++++++++----------------- qwen/research_attention.py | 39 ++++++++++++++++++++------- 2 files changed, 60 insertions(+), 34 deletions(-) diff --git a/qwen/modeling_qwen.py b/qwen/modeling_qwen.py index 8da5842..7d83af9 100644 --- a/qwen/modeling_qwen.py +++ b/qwen/modeling_qwen.py @@ -209,30 +209,9 @@ class QwenRunner: outputs = self.forwardQWen(input_ids) next_token_scores = outputs[:, -1, :] - # repetition_penalty - penalty = qwen.config.repetition_penalty - score = torch.gather(next_token_scores, 1, input_ids) - # if score < 0 then repetition penalty has to be multiplied to reduce the token probabilities - score = torch.where(score < 0, score * penalty, score / penalty) - next_token_scores = next_token_scores.scatter_(1, input_ids, score) - - # top_p - top_p = qwen.config.top_p - filter_value = -float("Inf") - min_tokens_to_keep = 1 - sorted_logits, sorted_indices = torch.sort(next_token_scores, descending=False) - cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1) - # Remove tokens with cumulative top_p above the threshold (token with 0 are kept) - sorted_indices_to_remove = cumulative_probs <= (1 - top_p) - # Keep at least min_tokens_to_keep - sorted_indices_to_remove[..., -min_tokens_to_keep:] = 0 - # scatter sorted tensors to original indexing - indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) - next_token_scores = next_token_scores.masked_fill(indices_to_remove, filter_value) - - # sample - probs = nn.functional.softmax(next_token_scores, dim=-1) - next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) + next_token_scores = self.repetition_penalty(input_ids, next_token_scores) + next_token_scores = self.top_p(next_token_scores) + next_tokens = self.sample(next_token_scores) next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) @@ -379,3 +358,31 @@ class QwenRunner: # loss.backward() return lm_logits + + def repetition_penalty(self, input_ids, next_token_scores): + penalty = self.qwen.config.repetition_penalty + score = torch.gather(next_token_scores, 1, input_ids) + # if score < 0 then repetition penalty has to be multiplied to reduce the token probabilities + score = torch.where(score < 0, score * penalty, score / penalty) + next_token_scores = next_token_scores.scatter_(1, input_ids, score) + return next_token_scores + + def top_p(self, next_token_scores): + top_p = self.qwen.config.top_p + filter_value = -float("Inf") + min_tokens_to_keep = 1 + sorted_logits, sorted_indices = torch.sort(next_token_scores, descending=False) + cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1) + # Remove tokens with cumulative top_p above the threshold (token with 0 are kept) + sorted_indices_to_remove = cumulative_probs <= (1 - top_p) + # Keep at least min_tokens_to_keep + sorted_indices_to_remove[..., -min_tokens_to_keep:] = 0 + # scatter sorted tensors to original indexing + indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) + next_token_scores = next_token_scores.masked_fill(indices_to_remove, filter_value) + return next_token_scores + + def sample(self, next_token_scores): + probs = nn.functional.softmax(next_token_scores, dim=-1) + next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) + return next_tokens diff --git a/qwen/research_attention.py b/qwen/research_attention.py index 561953e..be8ee7e 100644 --- a/qwen/research_attention.py +++ b/qwen/research_attention.py @@ -70,28 +70,44 @@ def Dump_lm_head_weight(model): # Dump_lm_head_weight(model) +def DumpQK(query, key, causal_mask, index): + scale_factor = 1 / math.sqrt(query.size(-1)) + attn_weight = query @ key.transpose(-2, -1) * scale_factor + attn_weight = torch.softmax(attn_weight, dim=-1) + size = query.shape[2] + attn_mask = torch.ones(causal_mask.shape, dtype=query.dtype, device=query.device) + attn_mask.masked_fill_(causal_mask.logical_not(), float(0)) + qk = attn_weight * attn_mask + qk = qk[0] + prePath = "./temp/" + "q@k_seq_" + str(size) + "_layer_" + str(index) + ".png" + show.DumpTensorToImage(qk, prePath, GridValue=255) + + class ResearchRunner(QwenRunner): + def __init__(self, model): + super().__init__(model) + self.tokenDecode = [] + def attention(self, attention, query, key, value, causal_mask): query = query.permute(0, 2, 1, 3) key = key.permute(0, 2, 1, 3) value = value.permute(0, 2, 1, 3) - scale_factor = 1 / math.sqrt(query.size(-1)) - attn_weight = query @ key.transpose(-2, -1) * scale_factor - attn_weight = torch.softmax(attn_weight, dim=-1) - size = query.shape[2] - attn_mask = torch.ones(causal_mask.shape, dtype=query.dtype, device=query.device) - attn_mask.masked_fill_(causal_mask.logical_not(), float(0)) - qk = attn_weight * attn_mask - qk = qk[0] - prePath = "./temp/" + "q@k_seq_" + str(size) + "_layer_" + str(attention.index) + ".png" - show.DumpTensorToImage(qk, prePath, GridValue=255) + DumpQK(query, key, causal_mask, attention.index) attn_output = F.scaled_dot_product_attention(query, key, value, attn_mask=causal_mask).transpose(1, 2) context_layer = attention._merge_heads(attn_output, attention.num_heads, attention.head_dim) attn_output = attention.c_proj(context_layer) return attn_output + def sample(self, next_token_scores): + next_tokens = super().sample(next_token_scores) + decoded, response, end_reason = decode_tokens( + next_tokens, + tokenizer, + ) + self.tokenDecode.append(decoded) + return next_tokens runner = ResearchRunner(model) @@ -106,5 +122,8 @@ print(decode_tokens) # 日本的首都东京。<|im_end|> # <|endoftext|> +show.DumpListToFile(runner.tokenDecode, "./temp/token_decode_list.txt") + + if decode_tokens.split("\n")[-2] != """日本的首都东京。<|im_end|>""": raise ()