Add validate.
This commit is contained in:
parent
e048d0c531
commit
6d0a410731
|
@ -5,6 +5,8 @@ from model.light_module import ModelRunner
|
|||
from model.modeling_wit import QWenLMHeadModel
|
||||
|
||||
import numpy as np
|
||||
import pytorch_lightning as pl
|
||||
from logger import TBLogger
|
||||
|
||||
import math
|
||||
import sys
|
||||
|
@ -76,6 +78,25 @@ def get_inference(dataset, meaning):
|
|||
node.print()
|
||||
|
||||
|
||||
def validate(model, dataloader):
|
||||
|
||||
logger = TBLogger("./log/", name=conf.name)
|
||||
logger.log_hyperparams(configuration.class_to_dict(conf))
|
||||
|
||||
lit_trainer = pl.Trainer(
|
||||
accelerator="cuda",
|
||||
precision=conf.precision,
|
||||
logger=logger,
|
||||
strategy=conf.strategy,
|
||||
max_epochs=conf.max_epochs,
|
||||
)
|
||||
|
||||
dataloader.dataset.meaning_dataset.set_mask([0, 1], [0, -1])
|
||||
|
||||
val = lit_trainer.validate(model, dataloaders=dataloader)
|
||||
print(val)
|
||||
|
||||
|
||||
def get_contain_meaning(dataset, meaning):
|
||||
map = dataset.get_meaning_map()
|
||||
seqs = map.get_maps(contain=meaning)
|
||||
|
@ -133,7 +154,7 @@ def dump_dk(dataset, meaning):
|
|||
|
||||
if __name__ == "__main__":
|
||||
|
||||
log_path = "log/bigger/version_0/"
|
||||
log_path = "log/bigger/version_1/"
|
||||
|
||||
file = get_latest_file_safe(log_path + "/checkpoints")
|
||||
checkpoint_path = log_path + "checkpoints/" + file
|
||||
|
@ -146,18 +167,13 @@ if __name__ == "__main__":
|
|||
np.random.seed(conf.seed)
|
||||
runner = ModelRunner(qwen.llm)
|
||||
|
||||
train, val = ds.InitDataset(conf)
|
||||
val = val.dataset
|
||||
# get_dataset_set_freq(train.dataset)
|
||||
train_loader, val_loader = ds.InitDataset(conf)
|
||||
val = val_loader.dataset
|
||||
md = val.meaning_dataset
|
||||
map = md.get_meaning_map()
|
||||
meaning = 2995 # 844 849 995 991
|
||||
|
||||
# seq:844
|
||||
# seq:849
|
||||
# seq:991
|
||||
# seq:995
|
||||
meaning = 995
|
||||
meaning = 2995
|
||||
# get_dataset_set_freq(train_loader.dataset)
|
||||
|
||||
# node = map.get_nodetree(meaning)
|
||||
# node.print()
|
||||
|
@ -170,3 +186,5 @@ if __name__ == "__main__":
|
|||
get_contain_meaning(md, meaning)
|
||||
|
||||
dump_dk(md, meaning)
|
||||
|
||||
validate(qwen, val_loader)
|
||||
|
|
Loading…
Reference in New Issue