Refine query.
This commit is contained in:
parent
6d0a410731
commit
edb8d8103a
|
@ -78,10 +78,11 @@ def get_inference(dataset, meaning):
|
||||||
node.print()
|
node.print()
|
||||||
|
|
||||||
|
|
||||||
def validate(model, dataloader):
|
def validate(model, dataloader, mask_level=[0, 1], mask_index=[0, -1]):
|
||||||
|
|
||||||
logger = TBLogger("./log/", name=conf.name)
|
logger = TBLogger("./log/", name=conf.name)
|
||||||
logger.log_hyperparams(configuration.class_to_dict(conf))
|
logger.log_hyperparams(configuration.class_to_dict(conf))
|
||||||
|
model.eval()
|
||||||
|
|
||||||
lit_trainer = pl.Trainer(
|
lit_trainer = pl.Trainer(
|
||||||
accelerator="cuda",
|
accelerator="cuda",
|
||||||
|
@ -91,10 +92,10 @@ def validate(model, dataloader):
|
||||||
max_epochs=conf.max_epochs,
|
max_epochs=conf.max_epochs,
|
||||||
)
|
)
|
||||||
|
|
||||||
dataloader.dataset.meaning_dataset.set_mask([0, 1], [0, -1])
|
dataloader.dataset.meaning_dataset.set_mask(mask_level, mask_index)
|
||||||
|
|
||||||
val = lit_trainer.validate(model, dataloaders=dataloader)
|
val = lit_trainer.validate(model, dataloaders=dataloader, verbose=False)
|
||||||
print(val)
|
return val[0]["accuracy"]
|
||||||
|
|
||||||
|
|
||||||
def get_contain_meaning(dataset, meaning):
|
def get_contain_meaning(dataset, meaning):
|
||||||
|
@ -151,6 +152,8 @@ def dump_dk(dataset, meaning):
|
||||||
# print(sorted_logits.detach().cpu().numpy())
|
# print(sorted_logits.detach().cpu().numpy())
|
||||||
# print(sorted_indices.detach().cpu().numpy())
|
# print(sorted_indices.detach().cpu().numpy())
|
||||||
|
|
||||||
|
qwen.llm.hook_attention = None
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
||||||
|
@ -187,4 +190,4 @@ if __name__ == "__main__":
|
||||||
|
|
||||||
dump_dk(md, meaning)
|
dump_dk(md, meaning)
|
||||||
|
|
||||||
validate(qwen, val_loader)
|
print("validate accuracy: " + str(validate(qwen, val_loader)))
|
||||||
|
|
Loading…
Reference in New Issue