Add validate.
This commit is contained in:
parent
e048d0c531
commit
6d0a410731
|
@ -5,6 +5,8 @@ from model.light_module import ModelRunner
|
||||||
from model.modeling_wit import QWenLMHeadModel
|
from model.modeling_wit import QWenLMHeadModel
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import pytorch_lightning as pl
|
||||||
|
from logger import TBLogger
|
||||||
|
|
||||||
import math
|
import math
|
||||||
import sys
|
import sys
|
||||||
|
@ -76,6 +78,25 @@ def get_inference(dataset, meaning):
|
||||||
node.print()
|
node.print()
|
||||||
|
|
||||||
|
|
||||||
|
def validate(model, dataloader):
|
||||||
|
|
||||||
|
logger = TBLogger("./log/", name=conf.name)
|
||||||
|
logger.log_hyperparams(configuration.class_to_dict(conf))
|
||||||
|
|
||||||
|
lit_trainer = pl.Trainer(
|
||||||
|
accelerator="cuda",
|
||||||
|
precision=conf.precision,
|
||||||
|
logger=logger,
|
||||||
|
strategy=conf.strategy,
|
||||||
|
max_epochs=conf.max_epochs,
|
||||||
|
)
|
||||||
|
|
||||||
|
dataloader.dataset.meaning_dataset.set_mask([0, 1], [0, -1])
|
||||||
|
|
||||||
|
val = lit_trainer.validate(model, dataloaders=dataloader)
|
||||||
|
print(val)
|
||||||
|
|
||||||
|
|
||||||
def get_contain_meaning(dataset, meaning):
|
def get_contain_meaning(dataset, meaning):
|
||||||
map = dataset.get_meaning_map()
|
map = dataset.get_meaning_map()
|
||||||
seqs = map.get_maps(contain=meaning)
|
seqs = map.get_maps(contain=meaning)
|
||||||
|
@ -133,7 +154,7 @@ def dump_dk(dataset, meaning):
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
||||||
log_path = "log/bigger/version_0/"
|
log_path = "log/bigger/version_1/"
|
||||||
|
|
||||||
file = get_latest_file_safe(log_path + "/checkpoints")
|
file = get_latest_file_safe(log_path + "/checkpoints")
|
||||||
checkpoint_path = log_path + "checkpoints/" + file
|
checkpoint_path = log_path + "checkpoints/" + file
|
||||||
|
@ -146,18 +167,13 @@ if __name__ == "__main__":
|
||||||
np.random.seed(conf.seed)
|
np.random.seed(conf.seed)
|
||||||
runner = ModelRunner(qwen.llm)
|
runner = ModelRunner(qwen.llm)
|
||||||
|
|
||||||
train, val = ds.InitDataset(conf)
|
train_loader, val_loader = ds.InitDataset(conf)
|
||||||
val = val.dataset
|
val = val_loader.dataset
|
||||||
# get_dataset_set_freq(train.dataset)
|
|
||||||
md = val.meaning_dataset
|
md = val.meaning_dataset
|
||||||
map = md.get_meaning_map()
|
map = md.get_meaning_map()
|
||||||
|
meaning = 2995 # 844 849 995 991
|
||||||
|
|
||||||
# seq:844
|
# get_dataset_set_freq(train_loader.dataset)
|
||||||
# seq:849
|
|
||||||
# seq:991
|
|
||||||
# seq:995
|
|
||||||
meaning = 995
|
|
||||||
meaning = 2995
|
|
||||||
|
|
||||||
# node = map.get_nodetree(meaning)
|
# node = map.get_nodetree(meaning)
|
||||||
# node.print()
|
# node.print()
|
||||||
|
@ -170,3 +186,5 @@ if __name__ == "__main__":
|
||||||
get_contain_meaning(md, meaning)
|
get_contain_meaning(md, meaning)
|
||||||
|
|
||||||
dump_dk(md, meaning)
|
dump_dk(md, meaning)
|
||||||
|
|
||||||
|
validate(qwen, val_loader)
|
||||||
|
|
Loading…
Reference in New Issue