# Witllm/wit/inference.py
import torch
import numpy as np

from model.light_module import LightModule
from model.light_module import ModelRunner
import dataset.dataset as ds

if __name__ == "__main__":
    # Checkpoints from earlier training runs, kept for reference:
    # checkpoint_path = "log/bigger/version_0/checkpoints/epoch=19-step=98720.ckpt"
    # checkpoint_path = "log/bigger/version_1/checkpoints/epoch=14-step=74040.ckpt"
    # checkpoint_path = "log/bigger/version_3/checkpoints/epoch=46-step=231992.ckpt"
    # checkpoint_path = "log/bigger/version_6/checkpoints/epoch=43-step=217184.ckpt"
    checkpoint_path = "log/bigger/version_8/checkpoints/epoch=49-step=246800.ckpt"

    # Load the trained Lightning module and switch to inference mode.
    qwen = LightModule.load_from_checkpoint(checkpoint_path=checkpoint_path)
    qwen.eval()

    # Seed every RNG source so the evaluation below is reproducible.
    conf = qwen.config
    torch.manual_seed(conf.seed)
    np.random.seed(conf.seed)
    torch.cuda.manual_seed_all(conf.seed)

    runner = ModelRunner(qwen.llm)

    # Example of chatting with raw token ids (greedy, sample=False):
    # batch = torch.tensor([[11, 0, 3, 7, 15, 8, 10, 7, 14, 13, 1, 12, 13]], dtype=torch.int64)
    # sorted_logits, sorted_indices = runner.ChatTokens(batch, sample=False)
    # print(sorted_logits.detach().cpu().numpy())
    # print(sorted_indices.detach().cpu().numpy())

    # Fetch the first validation sample and its meaning-tree representation.
    val = ds.InitValDataset(conf).dataset
    md = val.meaning_dataset
    # NOTE: renamed from `map` to avoid shadowing the builtin.
    meaning_map = md.get_meaning_map()
    item = md.get_token(0)
    node = meaning_map.get_nodetree(md.get_meaning(0))

    # For every prefix of the sample, greedily predict the next token and
    # compare it with the ground-truth token; mismatches are recorded on the
    # node tree (as "ERR_<token>") and printed.
    for i in range(1, len(item)):
        batch = torch.tensor([item[:i]], dtype=torch.int64)
        sorted_logits, sorted_indices = runner.ChatTokens(batch, sample=False)
        # sorted_indices is assumed sorted by descending logit, so [0][0] is
        # the greedy next token — TODO confirm against ModelRunner.ChatTokens.
        next_token = sorted_indices.detach().cpu().numpy()[0][0]
        if item[i] != next_token:
            node.set_seq_prop(i, "ERR_" + str(next_token))
            print(str(item[i]) + " " + str(next_token) + " ERROR")
    node.print()