# Witllm/wit/inference.py
import pytorch_lightning as pl
import torch
from model.qwen_module import QwenModule
from model.modeling_wit import QwenRunner
from model.tokenization_qwen import QWenTokenizer
import numpy as np

import configuration
import dataset.dataset as ds
import dataset.node_tree as nt

if __name__ == "__main__":
    # Checkpoint produced by the "bigger" training run.
    checkpoint_path = "log/bigger/version_0/checkpoints/epoch=19-step=98720.ckpt"

    # Restore the Lightning module and put it in inference mode.
    qwen = QwenModule.load_from_checkpoint(checkpoint_path=checkpoint_path)
    qwen.eval()
    conf = qwen.config

    # Seed both RNGs from the training config so decoding is reproducible.
    torch.manual_seed(conf.seed)
    np.random.seed(conf.seed)

    runner = QwenRunner(qwen.llm)

    # Validation dataset built from the same configuration as training.
    val = ds.InitValDataset(conf).dataset
    md = val.meaning_dataset

    # `meaning_map` (renamed from `map` to avoid shadowing the builtin)
    # gives the meaning tree for sample 0; `item` is its token sequence.
    meaning_map = md.get_meaning_map()
    item = md.get_token(0)
    node = meaning_map.get_tree(md.get_meaning(0))

    # Teacher-forced next-token check: for every prefix of the sample,
    # take the model's greedy prediction (sample=False) and compare it
    # with the ground-truth next token.
    for i in range(1, len(item)):
        batch = torch.tensor([item[:i]], dtype=torch.int64)
        next_token = runner.ChatTokens(batch, sample=False).detach().cpu().numpy()[0]
        if item[i] != next_token:
            # Annotate the mismatch position in the meaning tree so the
            # final tree printout shows where the model went wrong.
            node.set_seq_prop(i, "ERR_" + str(next_token))
            print(f"{item[i]} {next_token} ERROR")
    node.print()