"""Token-level error analysis of a trained Qwen checkpoint.

Loads a Lightning checkpoint, then for validation sample 0 feeds every
prefix of its token sequence through the model with greedy decoding
(sample=False) and compares the predicted next token against the ground
truth.  Each mismatch is recorded on the sample's meaning tree (as an
"ERR_<token>" property) and printed; the annotated tree is printed at
the end.
"""

import numpy as np
import pytorch_lightning as pl
import torch

import configuration
import dataset.dataset as ds
import dataset.node_tree as nt
from model.modeling_wit import QwenRunner
from model.qwen_module import QwenModule
from model.tokenization_qwen import QWenTokenizer

if __name__ == "__main__":
    # checkpoint_path = "log/bigger/version_0/checkpoints/epoch=19-step=98720.ckpt"
    checkpoint_path = "log/bigger/version_1/checkpoints/epoch=14-step=74040.ckpt"

    qwen = QwenModule.load_from_checkpoint(checkpoint_path=checkpoint_path)
    qwen.eval()
    conf = qwen.config
    # Seed both torch and numpy so the run is reproducible across invocations.
    torch.manual_seed(conf.seed)
    np.random.seed(conf.seed)
    runner = QwenRunner(qwen.llm)

    val = ds.InitValDataset(conf).dataset
    md = val.meaning_dataset
    # NOTE: renamed from `map` to avoid shadowing the builtin.
    meaning_map = md.get_meaning_map()
    item = md.get_token(0)
    node = meaning_map.get_nodetree(md.get_meaning(0))

    for i in range(1, len(item)):
        # Greedy next-token prediction from the prefix item[:i].
        batch = torch.tensor([item[:i]], dtype=torch.int64)
        sorted_logits, sorted_indices = runner.ChatTokens(batch, sample=False)
        # Top-1 prediction (indices are assumed sorted by descending logit —
        # TODO confirm against QwenRunner.ChatTokens).
        next_token = sorted_indices.detach().cpu().numpy()[0][0]
        if item[i] != next_token:
            # Tag the tree position so node.print() shows where decoding diverged.
            node.set_seq_prop(i, "ERR_" + str(next_token))
            print(str(item[i]) + " " + str(next_token) + " ERROR")
    node.print()