# Interactive tool: load a trained checkpoint, build meaning trees from the
# training dataset, and count occurrences of a queried meaning id.
# Third-party
import numpy as np
import pytorch_lightning as pl
import torch

# Project-local
import configuration
import dataset.dataset as ds
import dataset.node_tree as nt
from model.light_module import LightModule
from model.modeling_wit import ModelRunner
from model.tokenization_qwen import QWenTokenizer
|
def count_meaning(trees, meaning):
    """Return the total occurrences of *meaning* across all trees.

    Args:
        trees: mapping of meaning id -> tree; each tree only needs a
            list-like ``count(value)`` method.
        meaning: the value to count.

    Returns:
        int: sum of ``tree.count(meaning)`` over every tree.
    """
    return sum(tree.count(meaning) for tree in trees.values())


if __name__ == "__main__":
    checkpoint_path = "log/bigger/version_1/checkpoints/epoch=14-step=74040.ckpt"

    # Restore the trained module and switch it to inference mode.
    qwen = LightModule.load_from_checkpoint(checkpoint_path=checkpoint_path)
    qwen.eval()
    conf = qwen.config

    # Seed both torch and numpy so dataset construction is reproducible.
    torch.manual_seed(conf.seed)
    np.random.seed(conf.seed)

    train_dataloader, val_dataloader = ds.InitDataset(conf)

    loader = train_dataloader.dataset

    # Renamed from `map`, which shadowed the builtin.
    meaning_map = loader.meaning_dataset.get_meaning_map()

    # Materialize the meaning tree for every meaning id that appears in
    # the training split, so each interactive query below only walks the
    # cached trees instead of the dataset.
    trees = {}
    for batch in loader:
        for meaning in batch["meaning"]:
            trees[meaning] = meaning_map.get_tree(meaning)

    # Interactive query loop. EOF (Ctrl-D) or Ctrl-C exits cleanly
    # instead of dying with a traceback; non-integer input is reported
    # and the prompt is repeated.
    while True:
        try:
            query = int(input("input meaning: "))
        except (EOFError, KeyboardInterrupt):
            break
        except ValueError:
            print("please enter an integer meaning id")
            continue
        total = count_meaning(trees, query)
        print(f"meaning of {query} count as {total}")