import sys

import torch

# from modelscope import snapshot_download
from transformers import AutoConfig, AutoTokenizer

from modeling_qwen import QWenLMHeadModel, QwenRunner
from qwen_generation_utils import decode_tokens

sys.path.append("..")
from tools import show

seed = 4321
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

# model_dir = snapshot_download("qwen/Qwen-1_8B-Chat")
model_dir = "/home/colin/.cache/modelscope/hub/qwen/Qwen-1_8B-Chat"

# Read the local config.json, build the modified QWen model, then load the
# pretrained weights from model_dir.
config, kwargs = AutoConfig.from_pretrained(
    "./",
    return_unused_kwargs=True,
    trust_remote_code=True,
    code_revision=None,
    _commit_hash=None,
)
model = QWenLMHeadModel(config)
print(model)
tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
model = model.from_pretrained(model_dir)
if torch.cuda.device_count() > 0:
    model = model.cuda()
model = model.eval()


class ResearchRunner(QwenRunner):
    def __init__(self, model):
        super().__init__(model)

    def forwardQWen(
        self,
        input_ids=None,
        labels=None,
    ):
        # Embed the input ids, prepare the rotary position embeddings, and run
        # only the first transformer block: we only need its query/key projections.
        transfm = self.qwen.transformer
        input_shape = input_ids.size()
        input_ids = input_ids.view(-1, input_shape[-1])
        hidden_states = transfm.wte(input_ids)
        kv_seq_len = hidden_states.size()[1]
        transfm.update_rotary_pos_emb_cache(kv_seq_len, ntk_alpha=1.0)
        cos, sin = transfm._rotary_pos_emb_cache
        rotary_pos_emb_list = [[cos[:, :kv_seq_len], sin[:, :kv_seq_len]]]
        hidden_states = transfm.drop(hidden_states)
        for block in transfm.h:
            self.forwardQWenBlock(block, hidden_states, rotary_pos_emb_list=rotary_pos_emb_list)
            break  # only the first block is needed

    def forwardQWenBlock(
        self,
        block,
        hidden_states,
        rotary_pos_emb_list=None,
    ):
        layernorm_output = block.ln_1(hidden_states)
        self.forwardAttention(block.attn, layernorm_output, rotary_pos_emb_list)

    def attention(self, attention, query, key, value, causal_mask):
        # Instead of computing the attention output, capture the query/key
        # vectors of the selected head and append them to the global buffers.
        query = query.permute(0, 2, 1, 3)
        key = key.permute(0, 2, 1, 3)
        value = value.permute(0, 2, 1, 3)
        global q
        global k
        query = query[:, head_group_index, :, :]
        key = key[:, head_group_index, :, :]
        # Cast to the buffer dtype so concatenation works regardless of the model dtype.
        q = torch.cat([q, query.to(q.dtype)], 1)
        k = torch.cat([k, key.to(k.dtype)], 1)


head_group_index = 0  # which attention head to inspect
total_token = 151851  # size of the Qwen tokenizer vocabulary
topk = 10

# Decode every token id into a printable string for the report.
tokens_str = []
for token in range(total_token):
    decoded, response, end_reason = decode_tokens(
        [token],
        tokenizer,
        raw_text_len=0,
        context_length=0,
        errors="replace",
    )
    tokens_str.append(repr(decoded))

# Feed the whole vocabulary through the first block in patches of 1000 tokens.
patch_end = list(range(0, total_token, 1000))
patch_end = patch_end[1:] + [total_token]
patch_start = 0

device = next(model.parameters()).device
q = torch.zeros((1, 0, 128), dtype=torch.float64, device=device)  # 128 = head_dim
k = torch.zeros((1, 0, 128), dtype=torch.float64, device=device)

for end in patch_end:
    tokens = list(range(patch_start, end))
    patch_start = end
    input_ids = torch.tensor([tokens]).to(device)
    runner = ResearchRunner(model)
    runner.forwardQWen(input_ids)

# q: (total_token, head_dim), k: (head_dim, total_token)
q = q[0, :, :]
k = k[0, :, :].permute(1, 0)

# For every token, list the top-k tokens with the largest query-key score.
token_topk = []
for i in range(total_token):
    subq = q[i, :]
    qk = subq @ k
    values, indices = torch.topk(qk, topk)
    item = str(i).zfill(7) + " " + tokens_str[i] + " : "
    for index in indices:
        item += tokens_str[index] + " "
    token_topk.append(item)

show.DumpListToFile(token_topk, "./temp/qwen_token_qk_topk_head_group_" + str(head_group_index) + ".txt")
print("done")