From 1811b9611abb70a72f54f54e08c9584e21af1d46 Mon Sep 17 00:00:00 2001 From: Colin Date: Mon, 22 Jan 2024 20:57:27 +0800 Subject: [PATCH] Refine research_attention. --- qwen/qwen_generation_utils.py | 65 ++++++++++++----------------------- qwen/research_attention.py | 42 +++++++++++++++++++--- test/tensor.py | 11 ++++-- tools/show.py | 11 +++--- tools/test.py | 4 ++- 5 files changed, 78 insertions(+), 55 deletions(-) diff --git a/qwen/qwen_generation_utils.py b/qwen/qwen_generation_utils.py index 3ed754f..cc80126 100644 --- a/qwen/qwen_generation_utils.py +++ b/qwen/qwen_generation_utils.py @@ -47,9 +47,9 @@ def get_ltor_masks_and_position_ids( att_mask_batch = micro_batch_size else: att_mask_batch = 1 - attention_mask = torch.tril( - torch.ones((att_mask_batch, seq_length, seq_length), device=data.device) - ).view(att_mask_batch, 1, seq_length, seq_length) + attention_mask = torch.tril(torch.ones((att_mask_batch, seq_length, seq_length), device=data.device)).view( + att_mask_batch, 1, seq_length, seq_length + ) # Loss mask. loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device) @@ -66,7 +66,6 @@ def get_ltor_masks_and_position_ids( if reset_position_ids or reset_attention_mask: # Loop through the batches: for b in range(micro_batch_size): - # Find indecies where EOD token is. eod_index = position_ids[b, data[b] == eod_token] # Detach indecies from positions if going to modify positions. @@ -105,13 +104,14 @@ def get_batch(context_tokens: torch.LongTensor, eod_id: int): ) return tokens, attention_mask, position_ids + def make_context( tokenizer: PreTrainedTokenizer, query: str, query_assistant: str = "", history: List[Tuple[str, str]] = None, system: str = "", - max_window_size: int = 6144 + max_window_size: int = 6144, ): if history is None: history = [] @@ -122,32 +122,26 @@ def make_context( nl_tokens = tokenizer.encode("\n") def _tokenize_str(role, content): - return f"{role}\n{content}", tokenizer.encode( - role, allowed_special=set() - ) + nl_tokens + tokenizer.encode(content, allowed_special=set()) + return f"{role}\n{content}", tokenizer.encode(role, allowed_special=set()) + nl_tokens + tokenizer.encode( + content, allowed_special=set() + ) system_text, system_tokens_part = _tokenize_str("system", system) system_tokens = im_start_tokens + system_tokens_part + im_end_tokens - assistant_tokens = tokenizer.encode(query_assistant, allowed_special=set()) + assistant_tokens = tokenizer.encode(query_assistant, allowed_special=set()) raw_text = "" context_tokens = [] for turn_query, turn_response in reversed(history): query_text, query_tokens_part = _tokenize_str("user", turn_query) query_tokens = im_start_tokens + query_tokens_part + im_end_tokens - response_text, response_tokens_part = _tokenize_str( - "assistant", turn_response - ) + response_text, response_tokens_part = _tokenize_str("assistant", turn_response) response_tokens = im_start_tokens + response_tokens_part + im_end_tokens next_context_tokens = nl_tokens + query_tokens + nl_tokens + response_tokens - prev_chat = ( - f"\n{im_start}{query_text}{im_end}\n{im_start}{response_text}{im_end}" - ) + prev_chat = f"\n{im_start}{query_text}{im_end}\n{im_start}{response_text}{im_end}" - current_context_size = ( - len(system_tokens) + len(next_context_tokens) + len(context_tokens) - ) + current_context_size = len(system_tokens) + len(next_context_tokens) + len(context_tokens) if current_context_size < max_window_size: context_tokens = next_context_tokens + context_tokens raw_text = prev_chat + raw_text @@ -171,12 +165,13 @@ def make_context( return raw_text, context_tokens + def decode_tokens( tokens: Union[torch.LongTensor, TokensType], tokenizer: PreTrainedTokenizer, - raw_text_len: int, - context_length: int, - errors: str="replace", + raw_text_len: int = 0, + context_length: int = 0, + errors: str = "replace", ) -> str: if torch.is_tensor(tokens): tokens = tokens.cpu().numpy().tolist() @@ -211,42 +206,26 @@ class StopWordsLogitsProcessor(LogitsProcessor): """ def __init__(self, stop_words_ids: Iterable[Iterable[int]], eos_token_id: int): - if not isinstance(stop_words_ids, List) or len(stop_words_ids) == 0: - raise ValueError( - f"`stop_words_ids` has to be a non-emtpy list, but is {stop_words_ids}." - ) + raise ValueError(f"`stop_words_ids` has to be a non-emtpy list, but is {stop_words_ids}.") if any(not isinstance(bad_word_ids, list) for bad_word_ids in stop_words_ids): - raise ValueError( - f"`stop_words_ids` has to be a list of lists, but is {stop_words_ids}." - ) + raise ValueError(f"`stop_words_ids` has to be a list of lists, but is {stop_words_ids}.") if any( - any( - (not isinstance(token_id, (int, np.integer)) or token_id < 0) - for token_id in stop_word_ids - ) + any((not isinstance(token_id, (int, np.integer)) or token_id < 0) for token_id in stop_word_ids) for stop_word_ids in stop_words_ids ): raise ValueError( f"Each list in `stop_words_ids` has to be a list of positive integers, but is {stop_words_ids}." ) - self.stop_words_ids = list( - filter( - lambda bad_token_seq: bad_token_seq != [eos_token_id], stop_words_ids - ) - ) + self.stop_words_ids = list(filter(lambda bad_token_seq: bad_token_seq != [eos_token_id], stop_words_ids)) self.eos_token_id = eos_token_id for stop_token_seq in self.stop_words_ids: - assert ( - len(stop_token_seq) > 0 - ), "Stop words token sequences {} cannot have an empty list".format( + assert len(stop_token_seq) > 0, "Stop words token sequences {} cannot have an empty list".format( stop_words_ids ) - def __call__( - self, input_ids: torch.LongTensor, scores: torch.FloatTensor - ) -> torch.FloatTensor: + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: stopped_samples = self._calc_stopped_samples(input_ids) for i, should_stop in enumerate(stopped_samples): if should_stop: diff --git a/qwen/research_attention.py b/qwen/research_attention.py index f0a42d8..561953e 100644 --- a/qwen/research_attention.py +++ b/qwen/research_attention.py @@ -10,6 +10,11 @@ from modeling_qwen import QwenRunner import torch.nn.functional as F +from qwen_generation_utils import ( + make_context, + decode_tokens, +) + sys.path.append("..") from tools import show @@ -32,10 +37,39 @@ model = QWenLMHeadModel(config) print(model) tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True) -model = model.from_pretrained(model_dir).cuda() + +model = model.from_pretrained(model_dir) model = model.eval() +def Dump_tokens_list(model): + tokens = [] + for token in range(config.eos_token_id): + decoded, response, end_reason = decode_tokens( + [token], + tokenizer, + raw_text_len=0, + context_length=0, + errors="replace", + ) + tokens.append(str(token) + ": " + decoded) + show.DumpListToFile(tokens, "./temp/qwen_token_list.txt") + + +# Dump_tokens_list(model) + + +def Dump_lm_head_weight(model): + weight = model.lm_head.weight.cpu() # [151936,2048,] + weight = weight.reshape(64, -1, 2048) + for i in range(64): + sub = weight[i].reshape(-1, 64, 32) + show.DumpTensorToImage(sub, "./temp/lm_head_" + str(i) + "_2374_2048.png") + + +# Dump_lm_head_weight(model) + + class ResearchRunner(QwenRunner): def attention(self, attention, query, key, value, causal_mask): query = query.permute(0, 2, 1, 3) @@ -50,8 +84,8 @@ class ResearchRunner(QwenRunner): attn_mask.masked_fill_(causal_mask.logical_not(), float(0)) qk = attn_weight * attn_mask qk = qk[0] - prePath = "./temp/" - show.DumpTensorToImage(qk, prePath + "q@k_seq_" + str(size) + "_layer_" + str(attention.index) + ".png") + prePath = "./temp/" + "q@k_seq_" + str(size) + "_layer_" + str(attention.index) + ".png" + show.DumpTensorToImage(qk, prePath, GridValue=255) attn_output = F.scaled_dot_product_attention(query, key, value, attn_mask=causal_mask).transpose(1, 2) context_layer = attention._merge_heads(attn_output, attention.num_heads, attention.head_dim) @@ -71,6 +105,6 @@ print(decode_tokens) # <|im_start|>assistant # 日本的首都东京。<|im_end|> # <|endoftext|> - # 日本的首都是东京。<|im_end|> + if decode_tokens.split("\n")[-2] != """日本的首都东京。<|im_end|>""": raise () diff --git a/test/tensor.py b/test/tensor.py index eef0b44..a239b34 100644 --- a/test/tensor.py +++ b/test/tensor.py @@ -1,8 +1,16 @@ import torch import torch.nn.functional as F -x = torch.tensor([[1, 2], [3, 4]]) +x1 = torch.tensor([[1, 2]], dtype=float) +x2 = torch.tensor([[5, 6], [7, 8]], dtype=float) +y = x1 @ x2 # torch.matmul(x1 , x2) +x_inverse = torch.inverse(x2, out=None) +y_inverse = y @ x_inverse +y_inverse = y_inverse.permute(1, 0) + + +x = torch.tensor([[1, 2], [3, 4]], dtype=float) print(x) print("x.tile((2)) -> ", x.tile((2)).shape) print(x.tile((2))) @@ -55,4 +63,3 @@ z = torch.ones([1, 32, 6, 128]) att = torch.matmul(x, y) mm = torch.matmul(att, z) print(mm.shape) - diff --git a/tools/show.py b/tools/show.py index f83d662..9ac7bcf 100644 --- a/tools/show.py +++ b/tools/show.py @@ -8,23 +8,24 @@ import numpy as np import os -def DumpTensorToImage(tensor, name, forceSquare=True, scale=1.0, AutoContrast=True): +def DumpTensorToImage(tensor, name, forceSquare=True, scale=1.0, AutoContrast=True, GridValue=0): if len(tensor.shape) != 2 and len(tensor.shape) != 1 and len(tensor.shape) != 3: raise ("Error input dims") if len(tensor.shape) == 3: channel = tensor.shape[0] x = math.ceil((channel) ** 0.5) - tensor = F.pad(tensor, (0, 1, 0, 1, 0, x * x - channel), mode="constant", value=0) + tensor = F.pad(tensor, (0, 0, 0, 0, 0, x * x - channel), mode="constant", value=0) + calc = tensor.reshape((x * x, tensor.shape[1] * tensor.shape[2])) if AutoContrast: - calc = tensor.reshape((x * x, tensor.shape[1] * tensor.shape[2])) tensormax = calc.max(1)[0] tensormin = calc.min(1)[0] calc = calc.transpose(1, 0) calc = ((calc - tensormin) / (tensormax - tensormin)) * 255 calc = calc.transpose(1, 0) - tensor = calc.reshape((x, x, tensor.shape[1], tensor.shape[2])) - tensor = tensor.permute((0, 2, 1, 3)) + calc = calc.reshape((x, x, tensor.shape[1], tensor.shape[2])) + calc = F.pad(calc, (0, 1, 0, 1, 0, 0), mode="constant", value=GridValue) + tensor = calc.permute((0, 2, 1, 3)) tensor = tensor.reshape((x * tensor.shape[1], x * tensor.shape[3])) DumpTensorToImage(tensor, name, forceSquare=False, scale=scale, AutoContrast=False) return diff --git a/tools/test.py b/tools/test.py index ae26f69..d1e16bd 100644 --- a/tools/test.py +++ b/tools/test.py @@ -12,7 +12,9 @@ show.DumpTensorToImage(radata, "test.png") radata = torch.randn(3, 127, 127) -show.DumpTensorToImage(radata, "test.png") +show.DumpTensorToImage(radata, "test1.png", AutoContrast=True, GridValue=0) +show.DumpTensorToImage(radata, "test2.png", AutoContrast=True, GridValue=255) +show.DumpTensorToImage(radata, "test3.png", AutoContrast=False, GridValue=0) radata = torch.randn(127, 127)