# Witllm/wit/inference.py
# (file-viewer header: 46 lines, 1.4 KiB, Python)

import torch
from model.light_module import LightModule
from model.light_module import ModelRunner
import numpy as np
import dataset.dataset as ds
# Checkpoints evaluated previously (kept for reference):
#   log/bigger/version_0/checkpoints/epoch=72-step=360328.ckpt
#   log/bigger/version_4/checkpoints/epoch=81-step=64288.ckpt
CHECKPOINT_PATH = "log/bigger/version_8/checkpoints/epoch=14-step=67455.ckpt"

# Sequence id to inspect; other ids examined previously: 844, 849, 991.
SEQUENCE_INDEX = 995


def predict_next_token(runner, prefix):
    """Return the top-ranked next-token id that ``runner`` predicts for ``prefix``.

    ``prefix`` is a sequence of token ids. It is wrapped into a batch of one
    int64 tensor and passed to ``runner.ChatTokens(..., sample=False)``.
    """
    batch = torch.tensor([prefix], dtype=torch.int64)
    # NOTE(review): assumes sorted_indices is (batch, vocab) ordered by
    # descending score, so [0][0] is the greedy choice -- confirm in ModelRunner.
    _, sorted_indices = runner.ChatTokens(batch, sample=False)
    return sorted_indices[0][0].item()


def find_mismatches(tokens, predict):
    """Yield ``(position, expected, predicted)`` for each wrongly predicted token.

    For every position ``pos`` in ``1 .. len(tokens) - 1``, ``predict`` is
    called with the prefix ``tokens[:pos]``. Its result is compared with
    ``tokens[pos]``. Sequences shorter than two tokens yield nothing.
    """
    for pos in range(1, len(tokens)):
        predicted = predict(tokens[:pos])
        if tokens[pos] != predicted:
            yield pos, tokens[pos], predicted


def main(checkpoint_path=CHECKPOINT_PATH, seq_index=SEQUENCE_INDEX):
    """Check greedy next-token predictions along one validation sequence.

    The checkpoint is loaded and seeded from its config. Every position of
    sequence ``seq_index`` is then predicted from its prefix. Each mismatch is
    printed as "<expected> <predicted> ERROR" and tagged on the sequence's node
    tree as ``"ERR_<predicted>"``. The annotated tree is printed at the end.
    """
    model = LightModule.load_from_checkpoint(checkpoint_path=checkpoint_path)
    model.eval()
    conf = model.config
    torch.manual_seed(conf.seed)
    np.random.seed(conf.seed)
    torch.cuda.manual_seed_all(conf.seed)
    runner = ModelRunner(model.llm)

    val = ds.InitValDataset(conf).dataset
    meaning_map = val.meaning_dataset.get_meaning_map()
    node = meaning_map.get_nodetree(seq_index)
    tokens = meaning_map.get_sequence(seq_index)[0]
    print("len of seq:" + str(len(tokens)))

    # Pure inference: skip building autograd graphs for every prefix.
    # The generator body runs inside this context as it is consumed.
    with torch.no_grad():
        mismatches = find_mismatches(
            tokens, lambda prefix: predict_next_token(runner, prefix)
        )
        for pos, expected, predicted in mismatches:
            node.set_seq_prop(pos, "ERR_" + str(predicted))
            print(str(expected) + " " + str(predicted) + " ERROR")
    node.print()


if __name__ == "__main__":
    main()