From b7c27af6c850daec94acf233a10844bbfcf68f7b Mon Sep 17 00:00:00 2001
From: Colin
Date: Mon, 29 Jan 2024 00:12:08 +0800
Subject: [PATCH] Add research_token to dump token relationships in attention
 layer 0.

---
 qwen/research_attention.py |   2 +-
 qwen/research_token.py     | 146 ++++++++++++++++++++++++++++++++++++++
 2 files changed, 147 insertions(+), 1 deletion(-)
 create mode 100644 qwen/research_token.py

diff --git a/qwen/research_attention.py b/qwen/research_attention.py
index 87fc294..37dd818 100644
--- a/qwen/research_attention.py
+++ b/qwen/research_attention.py
@@ -47,7 +47,7 @@ model = model.eval()
 
 def Dump_tokens_list(model):
     tokens = []
-    for token in range(config.eos_token_id):
+    for token in range(151851):
         decoded, response, end_reason = decode_tokens(
             [token],
             tokenizer,
diff --git a/qwen/research_token.py b/qwen/research_token.py
new file mode 100644
index 0000000..fac6bf3
--- /dev/null
+++ b/qwen/research_token.py
@@ -0,0 +1,146 @@
+import torch
+import sys
+
+# from modelscope import snapshot_download
+from transformers import AutoTokenizer
+from transformers import AutoConfig
+
+from modeling_qwen import QWenLMHeadModel
+from modeling_qwen import QwenRunner
+
+
+from qwen_generation_utils import (
+    make_context,
+    decode_tokens,
+)
+
+import torch.nn.functional as F
+
+sys.path.append("..")
+from tools import show
+
+seed = 4321
+torch.manual_seed(seed)
+torch.cuda.manual_seed_all(seed)
+
+# model_dir = snapshot_download("qwen/Qwen-1_8B-Chat")
+model_dir = "/home/colin/.cache/modelscope/hub/qwen/Qwen-1_8B-Chat"
+
+config, kwargs = AutoConfig.from_pretrained(
+    "./",
+    return_unused_kwargs=True,
+    trust_remote_code=True,
+    code_revision=None,
+    _commit_hash=None,
+)
+
+model = QWenLMHeadModel(config)
+print(model)
+
+tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
+
+model = model.from_pretrained(model_dir)
+if torch.cuda.device_count() > 0:
+    model = model.cuda()
+model = model.eval()
+
+
+class ResearchRunner(QwenRunner):
+    def __init__(self, model):
+        super().__init__(model)
+
+    def forwardQWen(
+        self,
+        input_ids=None,
+        labels=None,
+    ):
+        transfm = self.qwen.transformer
+        input_shape = input_ids.size()
+        input_ids = input_ids.view(-1, input_shape[-1])
+        hidden_states = transfm.wte(input_ids)
+        kv_seq_len = hidden_states.size()[1]
+
+        transfm.update_rotary_pos_emb_cache(kv_seq_len, ntk_alpha=1.0)
+        cos, sin = transfm._rotary_pos_emb_cache
+        rotary_pos_emb_list = [[cos[:, :kv_seq_len], sin[:, :kv_seq_len]]]
+
+        hidden_states = transfm.drop(hidden_states)
+        output_shape = input_shape + (hidden_states.size(-1),)
+
+        # Only layer 0 is needed: attention() below captures its q/k.
+        for block in transfm.h:
+            self.forwardQWenBlock(block, hidden_states, rotary_pos_emb_list=rotary_pos_emb_list)
+            break
+
+    def forwardQWenBlock(
+        self,
+        block,
+        hidden_states,
+        rotary_pos_emb_list=None,
+    ):
+        layernorm_output = block.ln_1(hidden_states)
+        self.forwardAttention(block.attn, layernorm_output, rotary_pos_emb_list)
+
+    def attention(self, attention, query, key, value, causal_mask):
+        # (batch, seq, heads, head_dim) -> (batch, heads, seq, head_dim)
+        query = query.permute(0, 2, 1, 3)
+        key = key.permute(0, 2, 1, 3)
+        value = value.permute(0, 2, 1, 3)
+        global q
+        global k
+        # Keep one head and append this chunk's rows along the sequence axis.
+        query = query[:, head_group_index, :, :]
+        key = key[:, head_group_index, :, :]
+        q = torch.cat([q, query], 1)
+        k = torch.cat([k, key], 1)
+
+
+head_group_index = 0
+total_token = 151851
+topk = 10
+
+tokens_str = []
+for token in range(total_token):
+    decoded, response, end_reason = decode_tokens(
+        [token],
+        tokenizer,
+        raw_text_len=0,
+        context_length=0,
+        errors="replace",
+    )
+    tokens_str.append(repr(decoded))
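+
+# Scan the vocabulary in chunks of 1000 ids: each chunk is run through layer 0
+# once and the selected head's q/k rows accumulate in the global buffers, so
+# the 151851-token scan never needs a single giant forward pass.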
+patch_end = list(range(0, total_token, 1000))
+patch_end = patch_end[1:] + [total_token]
+patch_start = 0
+# Accumulators for the selected head's query/key rows (head_dim = 128); use
+# the model's dtype so torch.cat does not promote everything to float64.
+q = torch.zeros((1, 0, 128), dtype=next(model.parameters()).dtype).to(next(model.parameters()).device)
+k = torch.zeros((1, 0, 128), dtype=next(model.parameters()).dtype).to(next(model.parameters()).device)
+for end in patch_end:
+    tokens = list(range(patch_start, end))
+    patch_start = end
+    input_ids = torch.tensor([tokens]).to(next(model.parameters()).device)
+    runner = ResearchRunner(model)
+    runner.forwardQWen(input_ids)
+
+q = q[0, :, :]
+k = k[0, :, :].permute(1, 0)
+token_topk = []
+# For every token id, rank all key rows by the raw q @ k logit and keep topk.
+for i in range(total_token):
+    subq = q[i, :]
+    qk = subq @ k
+    values, indices = torch.topk(qk, topk)
+    item = str(i).zfill(7) + " " + tokens_str[i] + " : "
+    for index in indices:
+        item += tokens_str[index] + " "
+    token_topk.append(item)
+
+show.DumpListToFile(token_topk, "./temp/qwen_token_qk_topk_head_group_" + str(head_group_index) + ".txt")
+
+print("decoded")
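
Each dump line pairs a zero-padded token id and its repr() string with the
ten tokens whose layer-0 keys score highest against that token's query for
the selected head. A minimal sketch for spot-checking one entry, assuming
the script above has been run with head_group_index = 0 and that
show.DumpListToFile writes one list item per line:

    # print the recorded top-10 neighbours of an arbitrary token id
    with open("./temp/qwen_token_qk_topk_head_group_0.txt") as f:
        lines = f.read().splitlines()
    print(lines[12345])  # 12345 is an arbitrary id below total_token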