diff --git a/wit/query_block_output.py b/wit/query_block_output.py index ab8886f..003f40e 100644 --- a/wit/query_block_output.py +++ b/wit/query_block_output.py @@ -5,6 +5,8 @@ from model.light_module import ModelRunner from model.modeling_wit import QWenLMHeadModel import numpy as np +import pytorch_lightning as pl +from logger import TBLogger import math import sys @@ -76,6 +78,25 @@ def get_inference(dataset, meaning): node.print() +def validate(model, dataloader): + + logger = TBLogger("./log/", name=conf.name) + logger.log_hyperparams(configuration.class_to_dict(conf)) + + lit_trainer = pl.Trainer( + accelerator="cuda", + precision=conf.precision, + logger=logger, + strategy=conf.strategy, + max_epochs=conf.max_epochs, + ) + + dataloader.dataset.meaning_dataset.set_mask([0, 1], [0, -1]) + + val = lit_trainer.validate(model, dataloaders=dataloader) + print(val) + + def get_contain_meaning(dataset, meaning): map = dataset.get_meaning_map() seqs = map.get_maps(contain=meaning) @@ -133,7 +154,7 @@ def dump_dk(dataset, meaning): if __name__ == "__main__": - log_path = "log/bigger/version_0/" + log_path = "log/bigger/version_1/" file = get_latest_file_safe(log_path + "/checkpoints") checkpoint_path = log_path + "checkpoints/" + file @@ -146,18 +167,13 @@ if __name__ == "__main__": np.random.seed(conf.seed) runner = ModelRunner(qwen.llm) - train, val = ds.InitDataset(conf) - val = val.dataset - # get_dataset_set_freq(train.dataset) + train_loader, val_loader = ds.InitDataset(conf) + val = val_loader.dataset md = val.meaning_dataset map = md.get_meaning_map() + meaning = 2995 # 844 849 995 991 - # seq:844 - # seq:849 - # seq:991 - # seq:995 - meaning = 995 - meaning = 2995 + # get_dataset_set_freq(train_loader.dataset) # node = map.get_nodetree(meaning) # node.print() @@ -170,3 +186,5 @@ if __name__ == "__main__": get_contain_meaning(md, meaning) dump_dk(md, meaning) + + validate(qwen, val_loader)