import pytorch_lightning as pl
import torch
from model.light_module import LightModule
from model.modeling_wit import ModelRunner
from model.tokenization_qwen import QWenTokenizer
import numpy as np
import configuration
import dataset.dataset as ds
import dataset.node_tree as nt

if __name__ == "__main__":
    # Restore the trained LightModule from its Lightning checkpoint and
    # switch it to inference mode before touching any data.
    checkpoint_path = "log/bigger/version_1/checkpoints/epoch=14-step=74040.ckpt"
    qwen = LightModule.load_from_checkpoint(checkpoint_path=checkpoint_path)
    qwen.eval()
    conf = qwen.config

    # Seed both torch and numpy so dataset construction/shuffling is
    # reproducible across runs.
    torch.manual_seed(conf.seed)
    np.random.seed(conf.seed)

    train_dataloader, val_dataloader = ds.InitDataset(conf)
    loader = train_dataloader.dataset
    # Renamed from `map` — the original name shadowed the builtin.
    meaning_map = loader.meaning_dataset.get_meaning_map()

    # Collect each meaning's token sequence once, keyed by meaning id.
    # NOTE(review): assumes batch["meaning"] yields hashable ids and
    # get_sequence returns something with a .count() method — confirm
    # against meaning_dataset.
    seqs = {}
    for batch in loader:
        for m in batch["meaning"]:
            seqs[m] = meaning_map.get_sequence(m)

    # Interactive query loop: report how often the queried meaning id
    # occurs across all collected sequences. Non-integer input is
    # reported instead of crashing; EOF (Ctrl-D) or Ctrl-C exits cleanly.
    while True:
        try:
            m = int(input("input meaning: "))
        except (EOFError, KeyboardInterrupt):
            break
        except ValueError:
            print("please enter an integer meaning id")
            continue
        total = sum(seq.count(m) for seq in seqs.values())
        print(f"meaning of {m} count as {total}")