"""Interactive query tool: count occurrences of a meaning id in the training set.

Loads a trained ``LightModule`` checkpoint to recover its configuration,
builds the dataset from that configuration, collects one tree per meaning
seen in the training dataset, and then repeatedly asks the user for a
meaning id and prints the summed ``tree.count(id)`` over all collected trees.
"""

import pytorch_lightning as pl
import torch
from model.light_module import LightModule
from model.modeling_wit import ModelRunner
from model.tokenization_qwen import QWenTokenizer
import numpy as np

import configuration
import dataset.dataset as ds
import dataset.node_tree as nt


def _collect_trees(dataset):
    """Return a dict mapping every meaning found in *dataset* to its tree.

    Each item yielded by *dataset* is indexed with ``"meaning"`` and the
    result is iterated; the tree for each entry is looked up through the
    meaning map exposed by ``dataset.meaning_dataset``.
    """
    # Named ``meaning_map`` rather than ``map`` so the builtin is not shadowed.
    meaning_map = dataset.meaning_dataset.get_meaning_map()
    trees = {}
    for batch in dataset:
        for meaning in batch["meaning"]:
            trees[meaning] = meaning_map.get_tree(meaning)
    return trees


def _query_loop(trees):
    """Prompt for meaning ids until end of input and print their total count.

    Non-integer input is reported and the prompt is shown again instead of
    aborting the session. EOF (Ctrl-D) or Ctrl-C ends the loop cleanly.
    """
    while True:
        try:
            raw = input("input meaning: ")
        except (EOFError, KeyboardInterrupt):
            # End of input or user interrupt: leave without a traceback.
            print()
            return

        try:
            m = int(raw)
        except ValueError:
            print(f"invalid meaning id: {raw!r} (expected an integer)")
            continue

        total = sum(tree.count(m) for tree in trees.values())
        print(f"meaning of {m} count as {total}")


def main():
    """Load the checkpoint, build the tree table, and start the query loop."""
    checkpoint_path = "log/bigger/version_1/checkpoints/epoch=14-step=74040.ckpt"
    qwen = LightModule.load_from_checkpoint(checkpoint_path=checkpoint_path)
    qwen.eval()
    conf = qwen.config

    # Seed both RNGs from the checkpoint's config before the dataset is built.
    torch.manual_seed(conf.seed)
    np.random.seed(conf.seed)

    train_dataloader, val_dataloader = ds.InitDataset(conf)
    # NOTE: the underlying dataset object is iterated, not the dataloader.
    trees = _collect_trees(train_dataloader.dataset)

    _query_loop(trees)


if __name__ == "__main__":
    main()