Refine research_attention and forward model.

Factor the repetition-penalty, top-p and sampling code out of the generation loop into
QwenRunner.repetition_penalty/top_p/sample, move the Q@K attention dump into a standalone
DumpQK helper, and have ResearchRunner record the decoded token of every sampling step.
commit 11af10e710 (parent 1811b9611a)
@@ -209,30 +209,9 @@ class QwenRunner:
             outputs = self.forwardQWen(input_ids)
             next_token_scores = outputs[:, -1, :]
 
-            # repetition_penalty
-            penalty = qwen.config.repetition_penalty
-            score = torch.gather(next_token_scores, 1, input_ids)
-            # if score < 0 then repetition penalty has to be multiplied to reduce the token probabilities
-            score = torch.where(score < 0, score * penalty, score / penalty)
-            next_token_scores = next_token_scores.scatter_(1, input_ids, score)
-
-            # top_p
-            top_p = qwen.config.top_p
-            filter_value = -float("Inf")
-            min_tokens_to_keep = 1
-            sorted_logits, sorted_indices = torch.sort(next_token_scores, descending=False)
-            cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
-            # Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
-            sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
-            # Keep at least min_tokens_to_keep
-            sorted_indices_to_remove[..., -min_tokens_to_keep:] = 0
-            # scatter sorted tensors to original indexing
-            indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
-            next_token_scores = next_token_scores.masked_fill(indices_to_remove, filter_value)
-
-            # sample
-            probs = nn.functional.softmax(next_token_scores, dim=-1)
-            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
+            next_token_scores = self.repetition_penalty(input_ids, next_token_scores)
+            next_token_scores = self.top_p(next_token_scores)
+            next_tokens = self.sample(next_token_scores)
 
             next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
             input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
@@ -379,3 +358,31 @@ class QwenRunner:
         # loss.backward()
 
         return lm_logits
+
+    def repetition_penalty(self, input_ids, next_token_scores):
+        penalty = self.qwen.config.repetition_penalty
+        score = torch.gather(next_token_scores, 1, input_ids)
+        # if score < 0 then repetition penalty has to be multiplied to reduce the token probabilities
+        score = torch.where(score < 0, score * penalty, score / penalty)
+        next_token_scores = next_token_scores.scatter_(1, input_ids, score)
+        return next_token_scores
+
+    def top_p(self, next_token_scores):
+        top_p = self.qwen.config.top_p
+        filter_value = -float("Inf")
+        min_tokens_to_keep = 1
+        sorted_logits, sorted_indices = torch.sort(next_token_scores, descending=False)
+        cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
+        # Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
+        sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
+        # Keep at least min_tokens_to_keep
+        sorted_indices_to_remove[..., -min_tokens_to_keep:] = 0
+        # scatter sorted tensors to original indexing
+        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
+        next_token_scores = next_token_scores.masked_fill(indices_to_remove, filter_value)
+        return next_token_scores
+
+    def sample(self, next_token_scores):
+        probs = nn.functional.softmax(next_token_scores, dim=-1)
+        next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
+        return next_tokens
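The hunks above move the inline repetition-penalty, top-p, and multinomial-sampling code out of the generation loop and into QwenRunner methods. For orientation, below is a minimal self-contained sketch of that same pipeline applied to a dummy logits tensor; the helper names, the vocabulary size, and the penalty/top_p values are illustrative assumptions rather than values read from the Qwen config.

import torch


def apply_repetition_penalty(input_ids, scores, penalty):
    # Gather the scores of tokens that already appear in input_ids and penalize them:
    # negative scores are multiplied (pushed further down), positive scores divided.
    picked = torch.gather(scores, 1, input_ids)
    picked = torch.where(picked < 0, picked * penalty, picked / penalty)
    return scores.scatter(1, input_ids, picked)


def apply_top_p(scores, top_p, filter_value=-float("Inf"), min_tokens_to_keep=1):
    # Sort ascending so the lowest-probability tokens come first, then drop every token
    # whose cumulative probability stays below 1 - top_p, keeping at least one token.
    sorted_logits, sorted_indices = torch.sort(scores, descending=False)
    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
    sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
    sorted_indices_to_remove[..., -min_tokens_to_keep:] = False
    # Map the sorted mask back to the original token order.
    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
    return scores.masked_fill(indices_to_remove, filter_value)


def sample_next(scores):
    # Turn the filtered scores into probabilities and draw one token id per batch row.
    probs = torch.nn.functional.softmax(scores, dim=-1)
    return torch.multinomial(probs, num_samples=1).squeeze(1)


if __name__ == "__main__":
    logits = torch.randn(1, 16)          # stand-in for outputs[:, -1, :]
    history = torch.tensor([[2, 5, 5]])  # stand-in for the generated input_ids so far
    logits = apply_repetition_penalty(history, logits, penalty=1.1)
    logits = apply_top_p(logits, top_p=0.8)
    print(sample_next(logits))           # tensor with one next-token id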
@@ -70,28 +70,44 @@ def Dump_lm_head_weight(model):
 # Dump_lm_head_weight(model)
 
 
+def DumpQK(query, key, causal_mask, index):
+    scale_factor = 1 / math.sqrt(query.size(-1))
+    attn_weight = query @ key.transpose(-2, -1) * scale_factor
+    attn_weight = torch.softmax(attn_weight, dim=-1)
+    size = query.shape[2]
+    attn_mask = torch.ones(causal_mask.shape, dtype=query.dtype, device=query.device)
+    attn_mask.masked_fill_(causal_mask.logical_not(), float(0))
+    qk = attn_weight * attn_mask
+    qk = qk[0]
+    prePath = "./temp/" + "q@k_seq_" + str(size) + "_layer_" + str(index) + ".png"
+    show.DumpTensorToImage(qk, prePath, GridValue=255)
+
+
 class ResearchRunner(QwenRunner):
+    def __init__(self, model):
+        super().__init__(model)
+        self.tokenDecode = []
+
     def attention(self, attention, query, key, value, causal_mask):
         query = query.permute(0, 2, 1, 3)
         key = key.permute(0, 2, 1, 3)
         value = value.permute(0, 2, 1, 3)
 
-        scale_factor = 1 / math.sqrt(query.size(-1))
-        attn_weight = query @ key.transpose(-2, -1) * scale_factor
-        attn_weight = torch.softmax(attn_weight, dim=-1)
-        size = query.shape[2]
-        attn_mask = torch.ones(causal_mask.shape, dtype=query.dtype, device=query.device)
-        attn_mask.masked_fill_(causal_mask.logical_not(), float(0))
-        qk = attn_weight * attn_mask
-        qk = qk[0]
-        prePath = "./temp/" + "q@k_seq_" + str(size) + "_layer_" + str(attention.index) + ".png"
-        show.DumpTensorToImage(qk, prePath, GridValue=255)
+        DumpQK(query, key, causal_mask, attention.index)
 
         attn_output = F.scaled_dot_product_attention(query, key, value, attn_mask=causal_mask).transpose(1, 2)
         context_layer = attention._merge_heads(attn_output, attention.num_heads, attention.head_dim)
         attn_output = attention.c_proj(context_layer)
         return attn_output
 
+    def sample(self, next_token_scores):
+        next_tokens = super().sample(next_token_scores)
+        decoded, response, end_reason = decode_tokens(
+            next_tokens,
+            tokenizer,
+        )
+        self.tokenDecode.append(decoded)
+        return next_tokens
 
 runner = ResearchRunner(model)
@@ -106,5 +122,8 @@ print(decode_tokens)
 # 日本的首都东京。<|im_end|>
 # <|endoftext|>
 
+show.DumpListToFile(runner.tokenDecode, "./temp/token_decode_list.txt")
+
+
 if decode_tokens.split("\n")[-2] != """日本的首都东京。<|im_end|>""":
     raise ()
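The DumpQK helper added above visualizes softmax(query @ key.transpose(-2, -1) / sqrt(head_dim)) and only zeroes the masked cells after the softmax, so the dumped image stays blank above the diagonal. As a point of reference, the sketch below (toy shapes and names, not part of the commit) folds the boolean causal mask in before the softmax and checks that applying those weights to value reproduces F.scaled_dot_product_attention, the call the refined attention() still uses to compute the actual output.

import math

import torch
import torch.nn.functional as F

# Assumed toy shapes: (batch, heads, seq, head_dim), matching the layout after permute(0, 2, 1, 3).
batch, heads, seq, head_dim = 1, 2, 5, 8
query = torch.randn(batch, heads, seq, head_dim)
key = torch.randn(batch, heads, seq, head_dim)
value = torch.randn(batch, heads, seq, head_dim)
causal_mask = torch.tril(torch.ones(seq, seq)).bool()

# Manual attention: scale, mask future positions before the softmax, then weight the values.
scale_factor = 1 / math.sqrt(query.size(-1))
attn_weight = query @ key.transpose(-2, -1) * scale_factor
attn_weight = attn_weight.masked_fill(causal_mask.logical_not(), float("-inf"))
attn_weight = torch.softmax(attn_weight, dim=-1)
manual = attn_weight @ value

# PyTorch's fused implementation with the same boolean mask.
reference = F.scaled_dot_product_attention(query, key, value, attn_mask=causal_mask)
print(torch.allclose(manual, reference, atol=1e-5))  # True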