# Witllm/wit/query_meaning_freq.py
#
# 38 lines
# 1.0 KiB
# Python

import pytorch_lightning as pl
import torch
from model.light_module import LightModule
from model.modeling_wit import ModelRunner
from model.tokenization_qwen import QWenTokenizer
import numpy as np
import configuration
import dataset.dataset as ds
import dataset.node_tree as nt
def count_meaning(trees, meaning):
    """Return the sum of ``tree.count(meaning)`` over every tree in ``trees``.

    Args:
        trees: Mapping whose values expose a ``count(meaning)`` method.
        meaning: The meaning id passed to each tree's ``count``.

    Returns:
        The accumulated count; ``0`` when ``trees`` is empty.
    """
    # NOTE(review): presumably ``count`` reports occurrences of ``meaning``
    # inside a tree -- confirm against the tree type returned by get_tree.
    return sum(tree.count(meaning) for tree in trees.values())


def main():
    """Interactively report the summed ``count`` of a meaning across all trees.

    Loads the checkpointed module, seeds torch and numpy from its config,
    stores ``get_tree(m)`` for every ``m`` found under ``batch["meaning"]``
    in the dataset behind ``train_dataloader``, then answers integer queries
    typed on stdin until EOF or Ctrl-C.
    """
    checkpoint_path = "log/bigger/version_1/checkpoints/epoch=14-step=74040.ckpt"
    qwen = LightModule.load_from_checkpoint(checkpoint_path=checkpoint_path)
    qwen.eval()
    conf = qwen.config
    torch.manual_seed(conf.seed)
    np.random.seed(conf.seed)

    # The second value returned by InitDataset (the validation loader) is
    # not needed by this script.
    train_dataloader, _ = ds.InitDataset(conf)
    train_set = train_dataloader.dataset
    # Named meaning_map rather than `map` so the builtin is not shadowed.
    meaning_map = train_set.meaning_dataset.get_meaning_map()

    # Repeated keys simply overwrite the earlier entry.
    trees = {}
    for batch in train_set:
        for m in batch["meaning"]:
            trees[m] = meaning_map.get_tree(m)

    while True:
        try:
            raw = input("input meaning: ")
        except (EOFError, KeyboardInterrupt):
            # Leave the prompt loop cleanly instead of dying with a traceback.
            print()
            break
        try:
            meaning = int(raw)
        except ValueError:
            # A typo should not throw away the already-built tree table.
            print(f"invalid meaning {raw!r}: expected an integer")
            continue
        total = count_meaning(trees, meaning)
        print(f"meaning of {meaning} count as {total}")


if __name__ == "__main__":
    main()