Add validate.

This commit is contained in:
Colin 2025-08-26 14:07:48 +08:00
parent e048d0c531
commit 6d0a410731
1 changed file with 28 additions and 10 deletions

View File

@ -5,6 +5,8 @@ from model.light_module import ModelRunner
from model.modeling_wit import QWenLMHeadModel from model.modeling_wit import QWenLMHeadModel
import numpy as np import numpy as np
import pytorch_lightning as pl
from logger import TBLogger
import math import math
import sys import sys
@ -76,6 +78,25 @@ def get_inference(dataset, meaning):
node.print() node.print()
def validate(model, dataloader):
    """Run a PyTorch Lightning validation pass over ``dataloader``.

    Reads the module-level ``conf`` object (``name``, ``precision``,
    ``strategy``, ``max_epochs``) and the ``configuration`` module, so both
    must be in scope when this is called.

    Args:
        model: Module handed to ``Trainer.validate``.
        dataloader: Validation loader; its ``dataset.meaning_dataset`` must
            provide a ``set_mask`` method.

    Returns:
        The value returned by ``Trainer.validate`` (also printed to stdout),
        so callers can inspect the metrics instead of parsing the output.
    """
    logger = TBLogger("./log/", name=conf.name)
    logger.log_hyperparams(configuration.class_to_dict(conf))

    lit_trainer = pl.Trainer(
        accelerator="cuda",
        precision=conf.precision,
        logger=logger,
        strategy=conf.strategy,
        max_epochs=conf.max_epochs,
    )

    # NOTE(review): the mask arguments are passed through unchanged; their
    # meaning is defined by ``set_mask`` in the dataset code — confirm there.
    dataloader.dataset.meaning_dataset.set_mask([0, 1], [0, -1])

    val = lit_trainer.validate(model, dataloaders=dataloader)
    print(val)
    return val
def get_contain_meaning(dataset, meaning): def get_contain_meaning(dataset, meaning):
map = dataset.get_meaning_map() map = dataset.get_meaning_map()
seqs = map.get_maps(contain=meaning) seqs = map.get_maps(contain=meaning)
@ -133,7 +154,7 @@ def dump_dk(dataset, meaning):
if __name__ == "__main__": if __name__ == "__main__":
log_path = "log/bigger/version_0/" log_path = "log/bigger/version_1/"
file = get_latest_file_safe(log_path + "/checkpoints") file = get_latest_file_safe(log_path + "/checkpoints")
checkpoint_path = log_path + "checkpoints/" + file checkpoint_path = log_path + "checkpoints/" + file
@ -146,18 +167,13 @@ if __name__ == "__main__":
np.random.seed(conf.seed) np.random.seed(conf.seed)
runner = ModelRunner(qwen.llm) runner = ModelRunner(qwen.llm)
train, val = ds.InitDataset(conf) train_loader, val_loader = ds.InitDataset(conf)
val = val.dataset val = val_loader.dataset
# get_dataset_set_freq(train.dataset)
md = val.meaning_dataset md = val.meaning_dataset
map = md.get_meaning_map() map = md.get_meaning_map()
meaning = 2995 # 844 849 995 991
# seq:844 # get_dataset_set_freq(train_loader.dataset)
# seq:849
# seq:991
# seq:995
meaning = 995
meaning = 2995
# node = map.get_nodetree(meaning) # node = map.get_nodetree(meaning)
# node.print() # node.print()
@ -170,3 +186,5 @@ if __name__ == "__main__":
get_contain_meaning(md, meaning) get_contain_meaning(md, meaning)
dump_dk(md, meaning) dump_dk(md, meaning)
validate(qwen, val_loader)