From 3e6ff2d580b5b230740e6548b3ce8d0b9159e331 Mon Sep 17 00:00:00 2001 From: Colin <> Date: Sun, 17 Aug 2025 17:58:23 +0800 Subject: [PATCH] Refine query_block_output. --- wit/configuration.py | 14 ++++++ wit/inference.py | 46 ------------------- wit/query_block_output.py | 93 ++++++++++++++++++++++++++++++++------- wit/query_meaning_freq.py | 35 --------------- 4 files changed, 90 insertions(+), 98 deletions(-) delete mode 100644 wit/inference.py delete mode 100644 wit/query_meaning_freq.py diff --git a/wit/configuration.py b/wit/configuration.py index 4fab7fd..a45aa0b 100644 --- a/wit/configuration.py +++ b/wit/configuration.py @@ -1,3 +1,6 @@ +import pickle + + class ModelConfig: def __init__(self): self.vocab_size = 4096 @@ -90,6 +93,17 @@ def class_to_dict(obj): return str(obj) +def class_to_file(obj, file): + with open(file, "wb") as file: + pickle.dump(obj, file) + + +def class_from_file(file): + with open(file, "rb") as file: + obj = pickle.load(file) + return obj + + # train_config = TrainConfig() # train_config_dict = class_to_dict(train_config) # import pprint diff --git a/wit/inference.py b/wit/inference.py deleted file mode 100644 index 3449500..0000000 --- a/wit/inference.py +++ /dev/null @@ -1,46 +0,0 @@ -import torch - -from model.light_module import LightModule -from model.light_module import ModelRunner -import numpy as np - -import meaning.dataset as ds - -if __name__ == "__main__": - - # checkpoint_path = "log/bigger/version_0/checkpoints/epoch=72-step=360328.ckpt" - # checkpoint_path = "log/bigger/version_4/checkpoints/epoch=81-step=64288.ckpt" - checkpoint_path = "log/bigger/version_6/checkpoints/epoch=14-step=67455.ckpt" - - qwen = LightModule.load_from_checkpoint(checkpoint_path=checkpoint_path) - qwen.eval() - conf = qwen.config - torch.manual_seed(conf.seed) - np.random.seed(conf.seed) - torch.cuda.manual_seed_all(conf.seed) - - runner = ModelRunner(qwen.llm) - - _, val = ds.InitDataset(conf).dataset - md = val.meaning_dataset - map = md.get_meaning_map() - - # seq:844 - # seq:849 - # seq:991 - # seq:995 - seq = 995 - - node = map.get_nodetree(seq) - item, l, rank_idx, rank_all = map.get_sequence(seq) - print("len of seq:" + str(len(item))) - - for i in range(1, len(item)): - itemm = [item[:i]] - batch = torch.tensor([item[:i]], dtype=torch.int64) - sorted_logits, sorted_indices = runner.ChatTokens(batch, sample=False) - next_token = sorted_indices.detach().cpu().numpy()[0][0] - if item[i] != next_token: - node.set_seq_prop(i, "ERR_" + str(next_token)) - print(str(item[i]) + " " + str(next_token) + " ERROR") - node.print() diff --git a/wit/query_block_output.py b/wit/query_block_output.py index 397f9df..819958e 100644 --- a/wit/query_block_output.py +++ b/wit/query_block_output.py @@ -2,28 +2,100 @@ import torch from model.light_module import LightModule from model.light_module import ModelRunner +from model.modeling_wit import QWenLMHeadModel + import numpy as np import math import sys +import os sys.path.append("..") from tools import show - +import configuration import meaning.dataset as ds + +def get_latest_file_safe(directory): + try: + files = [f for f in os.listdir(directory) if os.path.isfile(os.path.join(directory, f))] + if not files: + print("警告:目录中没有文件") + return None + latest = max(files, key=lambda f: os.path.getmtime(os.path.join(directory, f))) + return latest + except Exception as e: + print(f"错误: {e}") + return None + + +def get_dataset_set_freq(dataset): + loader = dataset + + map = loader.meaning_dataset.get_meaning_map() + seqs = {} + for batch in loader: + for m in batch["meaning"]: + seqs[m] = map.get_sequence(m) + while True: + m = int(input("input meaning: ")) + total = 0 + for seq in seqs.values(): + total = total + seq.count(m) + print(f"meaning of {m} count as {total}") + + +def get_inference(dataset, seq): + map = dataset.get_meaning_map() + + node = map.get_nodetree(seq) + item, l, rank_idx, rank_all = map.get_sequence(seq) + print("len of seq:" + str(len(item))) + + for i in range(1, len(item)): + itemm = [item[:i]] + batch = torch.tensor([item[:i]], dtype=torch.int64) + sorted_logits, sorted_indices = runner.ChatTokens(batch, sample=False) + next_token = sorted_indices.detach().cpu().numpy()[0][0] + if item[i] != next_token: + node.set_seq_prop(i, "ERR_" + str(next_token)) + print(str(item[i]) + " " + str(next_token) + " ERROR") + node.print() + + if __name__ == "__main__": - checkpoint_path = "log/bigger/version_6/checkpoints/epoch=14-step=67455.ckpt" + log_path = "log/bigger/version_1/" - qwen = LightModule.load_from_checkpoint(checkpoint_path=checkpoint_path) + file = get_latest_file_safe(log_path + "/checkpoints") + checkpoint_path = log_path + "checkpoints/" + file + conf = configuration.class_from_file(log_path + "conf.pkl") + model = QWenLMHeadModel(conf.model_config) + qwen = LightModule.load_from_checkpoint(checkpoint_path=checkpoint_path, model=model) qwen.eval() conf = qwen.config torch.manual_seed(conf.seed) np.random.seed(conf.seed) runner = ModelRunner(qwen.llm) + train, val = ds.InitDataset(conf) + val = val.dataset + # get_dataset_set_freq(train.dataset) + md = val.meaning_dataset + map = md.get_meaning_map() + + # seq:844 + # seq:849 + # seq:991 + # seq:995 + meaning = 995 + + get_inference(md, meaning) + + node = map.get_nodetree(meaning) + node.print() + def DumpQK(query, key, causal_mask, index): global relation_distance size = query.shape[2] @@ -37,26 +109,13 @@ if __name__ == "__main__": qk = attn_weight[0] prePath = "./temp/" + "q@k_seq_" + str(size) + "_layer_" + str(index) + ".png" qk = qk.cpu() - qk = torch.cat((qk, relation_distance.unsqueeze(0)), dim=0) + # qk = torch.cat((qk, relation_distance.unsqueeze(0)), dim=0) show.DumpTensorToImage(qk, prePath) # qk_seq.append(qk) # qk_index = size qwen.llm.hook_attention = DumpQK - _, val = ds.InitDataset(conf).dataset - md = val.meaning_dataset - map = md.get_meaning_map() - - # seq:844 - # seq:849 - # seq:991 - # seq:995 - meaning = 995 - - node = map.get_nodetree(meaning) - node.print() - # current_to_common, common_to_current = map.get_level_change(meaning) # print(current_to_common) # print(common_to_current) diff --git a/wit/query_meaning_freq.py b/wit/query_meaning_freq.py deleted file mode 100644 index 855f75a..0000000 --- a/wit/query_meaning_freq.py +++ /dev/null @@ -1,35 +0,0 @@ -import pytorch_lightning as pl -import torch - -from model.light_module import LightModule -from model.tokenization_qwen import QWenTokenizer -import numpy as np - -import configuration -import meaning as m - -if __name__ == "__main__": - - checkpoint_path = "log/bigger/version_1/checkpoints/epoch=14-step=74040.ckpt" - - qwen = LightModule.load_from_checkpoint(checkpoint_path=checkpoint_path) - qwen.eval() - conf = qwen.config - torch.manual_seed(conf.seed) - np.random.seed(conf.seed) - - train_dataloader, val_dataloader = m.InitDataset(conf) - - loader = train_dataloader.dataset - - map = loader.meaning_dataset.get_meaning_map() - seqs = {} - for batch in loader: - for m in batch["meaning"]: - seqs[m] = map.get_sequence(m) - while True: - m = int(input("input meaning: ")) - total = 0 - for seq in seqs.values(): - total = total + seq.count(m) - print(f"meaning of {m} count as {total}")