Witllm/wit/query_meaning_freq.py

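"""Interactively query how many times a given meaning id occurs in the
sequences generated for the training set of the meaning dataset."""
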
import pytorch_lightning as pl
import torch
from model.light_module import LightModule
from model.tokenization_qwen import QWenTokenizer
import numpy as np
import configuration
import meaning as m
if __name__ == "__main__":
    # Trained checkpoint to load the model from.
    checkpoint_path = "log/bigger/version_1/checkpoints/epoch=14-step=74040.ckpt"
    # Restore the trained LightModule from the checkpoint.
    qwen = LightModule.load_from_checkpoint(checkpoint_path=checkpoint_path)
    qwen.eval()

    conf = qwen.config
    # Seed the RNGs with the configured seed for reproducibility.
    torch.manual_seed(conf.seed)
    np.random.seed(conf.seed)
    # Rebuild the train/validation dataloaders from the model's configuration.
    train_dataloader, val_dataloader = m.InitDataset(conf)
    loader = train_dataloader.dataset
    # Maps a meaning id to its generated token sequence.
    meaning_map = loader.meaning_dataset.get_meaning_map()
    seqs = {}
    # Collect the sequence of every meaning id that occurs in the training set.
    for batch in loader:
        for meaning_id in batch["meaning"]:
            seqs[meaning_id] = meaning_map.get_sequence(meaning_id)
    # Interactively report how often a queried meaning id appears across all collected sequences.
    while True:
        meaning_id = int(input("input meaning: "))
        total = 0
        for seq in seqs.values():
            total += seq.count(meaning_id)
        print(f"meaning {meaning_id} appears {total} times")