import math
import os
import sys

import numpy as np
import torch

sys.path.append("..")
from model.light_module import LightModule, ModelRunner
from model.modeling_wit import QWenLMHeadModel
from tools import show
import configuration
import meaning.dataset as ds


def get_latest_file_safe(directory):
    """Return the name of the most recently modified file in directory, or None."""
    try:
        files = [f for f in os.listdir(directory) if os.path.isfile(os.path.join(directory, f))]
        if not files:
            print("Warning: no files in the directory")
            return None
        latest = max(files, key=lambda f: os.path.getmtime(os.path.join(directory, f)))
        return latest
    except Exception as e:
        print(f"Error: {e}")
        return None


def get_dataset_set_freq(dataset):
    """Interactively count how often a meaning token occurs across the dataset's sequences."""
    loader = dataset
    map = loader.meaning_dataset.get_meaning_map()
    seqs = {}
    for batch in loader:
        for m in batch["meaning"]:
            # get_sequence returns (item, level, rank_idx, rank_all); keep only the token sequence.
            seqs[m] = map.get_sequence(m)[0]
    while True:
        m = int(input("input meaning: "))
        total = 0
        for seq in seqs.values():
            total = total + seq.count(m)
        print(f"meaning {m} occurs {total} times")


def get_inference(dataset, seq):
    """Greedy-decode a sequence token by token and mark mispredicted positions on its node tree.

    Uses the global `runner` created in __main__.
    """
    map = dataset.get_meaning_map()
    node = map.get_nodetree(seq)
    item, level, rank_idx, rank_all = map.get_sequence(seq)
    print(f"len of seq: {len(item)}")
    for i in range(1, len(item)):
        # Feed the prefix item[:i] and compare the model's top prediction with the golden token.
        batch = torch.tensor([item[:i]], dtype=torch.int64)
        sorted_logits, sorted_indices = runner.ChatTokens(batch, sample=False)
        next_token = sorted_indices.detach().cpu().numpy()[0][0]
        if item[i] != next_token:
            node.set_seq_prop(i, "ERR_" + str(next_token))
            print(f"index: {i} golden: {item[i]} -> {next_token} ERR")
    node.print()


if __name__ == "__main__":
    # Load the latest checkpoint and its saved configuration.
    log_path = "log/bigger/version_2/"
    file = get_latest_file_safe(log_path + "checkpoints")
    if file is None:
        sys.exit("no checkpoint found in " + log_path + "checkpoints")
    checkpoint_path = log_path + "checkpoints/" + file
    conf = configuration.class_from_file(log_path + "conf.pkl")
    model = QWenLMHeadModel(conf.model_config)
    qwen = LightModule.load_from_checkpoint(checkpoint_path=checkpoint_path, model=model)
    qwen.eval()
    conf = qwen.config
    torch.manual_seed(conf.seed)
    np.random.seed(conf.seed)
    runner = ModelRunner(qwen.llm)

    train, val = ds.InitDataset(conf)
    val = val.dataset
    # get_dataset_set_freq(train.dataset)

    md = val.meaning_dataset
    map = md.get_meaning_map()

    # Candidate meanings to inspect:
    # seq:844
    # seq:849
    # seq:991
    # seq:995
    meaning = 991
    node = map.get_nodetree(meaning)
    node.print()
    get_inference(md, meaning)

    def DumpQK(query, key, causal_mask, index):
        """Attention hook: recompute one layer's attention map and dump it to an image."""
        global relation_distance
        size = query.shape[2]
        scale_factor = 1 / math.sqrt(query.size(-1))
        attn_weight = query @ key.transpose(-2, -1) * scale_factor
        # Additive causal masking: masked logits go to -inf so softmax gives
        # them exactly zero weight and each row still sums to 1.
        attn_weight = attn_weight.masked_fill(causal_mask.logical_not(), float("-inf"))
        attn_weight = torch.softmax(attn_weight, dim=-1)
        qk = attn_weight[0].cpu()
        # qk = torch.cat((qk, relation_distance.unsqueeze(0)), dim=0)
        prePath = "./temp/q@k_seq_" + str(size) + "_layer_" + str(index) + ".png"
        show.DumpTensorToImage(qk, prePath)
        # qk_seq.append(qk)
        # qk_index = size

    qwen.llm.hook_attention = DumpQK

    # current_to_common, common_to_current = map.get_level_change(meaning)
    # print(current_to_common)
    # print(common_to_current)

    # Dump the meaning-tree relation distances, tiled into a band so the 1-D vector is visible.
    relation_distance = map.get_relation_distance(meaning)
    relation_distance = torch.tensor(relation_distance)
    rd = relation_distance.unsqueeze(0).repeat(6, 1)
    show.DumpTensorToImage(rd, "./temp/q@k_seq_layer.png")

    item, level, rank_idx, rank_all = map.get_sequence(meaning)
    print(f"len of seq: {len(item)}")
    batch = torch.tensor([item], dtype=torch.int64)
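    # With the DumpQK hook installed, this single forward pass also writes one
    # attention image per hooked layer to ./temp/.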
    sorted_logits, sorted_indices = runner.ChatTokens(batch, sample=False)
    next_token = sorted_indices.detach().cpu().numpy()[0][0]
    print(f"predicted next token: {next_token}")

    # batch = torch.tensor([[11, 0, 3, 7, 15, 8, 10, 7]], dtype=torch.int64)
    # sorted_logits, sorted_indices = runner.ChatTokens(batch, sample=False)
    # print(sorted_logits.detach().cpu().numpy())
    # print(sorted_indices.detach().cpu().numpy())
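    # A minimal inspection sketch, assuming sorted_logits / sorted_indices are
    # [batch, vocab] tensors in descending logit order (as the [0][0] indexing
    # above implies): print the top-5 next-token candidates and their logits.
    top_k = 5
    top_tokens = sorted_indices.detach().cpu().numpy()[0][:top_k]
    top_logits = sorted_logits.detach().cpu().numpy()[0][:top_k]
    for token, logit in zip(top_tokens, top_logits):
        print(f"candidate token: {token} logit: {logit:.4f}")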