import torch
import numpy as np

from model.light_module import LightModule
from model.light_module import ModelRunner
import dataset.dataset as ds

if __name__ == "__main__":
    # Evaluation driver: load a trained checkpoint, run teacher-forced greedy
    # decoding over the first validation item, tag every mispredicted position
    # on its meaning-tree, and print the annotated tree at the end.

    # Earlier checkpoints kept for reference (previously live reassignments
    # that were immediately overwritten — now commented out like the others):
    # checkpoint_path = "log/bigger/version_0/checkpoints/epoch=19-step=98720.ckpt"
    # checkpoint_path = "log/bigger/version_1/checkpoints/epoch=14-step=74040.ckpt"
    # checkpoint_path = "log/bigger/version_3/checkpoints/epoch=46-step=231992.ckpt"
    # checkpoint_path = "log/bigger/version_6/checkpoints/epoch=43-step=217184.ckpt"
    checkpoint_path = "log/bigger/version_8/checkpoints/epoch=49-step=246800.ckpt"

    qwen = LightModule.load_from_checkpoint(checkpoint_path=checkpoint_path)
    qwen.eval()  # inference mode: disables dropout/batch-norm updates
    conf = qwen.config

    # Seed all RNG sources we rely on so the run is reproducible.
    torch.manual_seed(conf.seed)
    np.random.seed(conf.seed)
    torch.cuda.manual_seed_all(conf.seed)

    runner = ModelRunner(qwen.llm)

    # Example of querying a single hand-crafted token sequence:
    # batch = torch.tensor([[11, 0, 3, 7, 15, 8, 10, 7, 14, 13, 1, 12, 13]], dtype=torch.int64)
    # sorted_logits, sorted_indices = runner.ChatTokens(batch, sample=False)
    # print(sorted_logits.detach().cpu().numpy())
    # print(sorted_indices.detach().cpu().numpy())

    val = ds.InitValDataset(conf).dataset
    md = val.meaning_dataset
    # Renamed from `map`, which shadowed the builtin of the same name.
    meaning_map = md.get_meaning_map()
    item = md.get_token(0)

    node = meaning_map.get_nodetree(md.get_meaning(0))
    # node.print()

    # Teacher-forced greedy check: for each position i, feed the true prefix
    # item[:i] and compare the model's top-1 prediction against item[i].
    # (The original also built an unused `itemm = [item[:i]]` — removed.)
    for i in range(1, len(item)):
        batch = torch.tensor([item[:i]], dtype=torch.int64)
        sorted_logits, sorted_indices = runner.ChatTokens(batch, sample=False)
        # sample=False → indices sorted by logit; [0][0] is the argmax token.
        next_token = sorted_indices.detach().cpu().numpy()[0][0]
        if item[i] != next_token:
            # Tag the mispredicted position so it shows up in the tree dump.
            node.set_seq_prop(i, "ERR_" + str(next_token))
            print(str(item[i]) + " " + str(next_token) + " ERROR")

    # NOTE(review): source indentation was lost; the final tree dump is placed
    # after the loop (one summary print), which matches the script's intent.
    node.print()