Refine query.

This commit is contained in:
Colin 2025-08-26 15:47:03 +08:00
parent 6d0a410731
commit edb8d8103a
1 changed files with 8 additions and 5 deletions

View File

@ -78,10 +78,11 @@ def get_inference(dataset, meaning):
node.print()
def validate(model, dataloader):
def validate(model, dataloader, mask_level=[0, 1], mask_index=[0, -1]):
logger = TBLogger("./log/", name=conf.name)
logger.log_hyperparams(configuration.class_to_dict(conf))
model.eval()
lit_trainer = pl.Trainer(
accelerator="cuda",
@ -91,10 +92,10 @@ def validate(model, dataloader):
max_epochs=conf.max_epochs,
)
dataloader.dataset.meaning_dataset.set_mask([0, 1], [0, -1])
dataloader.dataset.meaning_dataset.set_mask(mask_level, mask_index)
val = lit_trainer.validate(model, dataloaders=dataloader)
print(val)
val = lit_trainer.validate(model, dataloaders=dataloader, verbose=False)
return val[0]["accuracy"]
def get_contain_meaning(dataset, meaning):
@ -151,6 +152,8 @@ def dump_dk(dataset, meaning):
# print(sorted_logits.detach().cpu().numpy())
# print(sorted_indices.detach().cpu().numpy())
qwen.llm.hook_attention = None
if __name__ == "__main__":
@ -187,4 +190,4 @@ if __name__ == "__main__":
dump_dk(md, meaning)
validate(qwen, val_loader)
print("validate accuracy: " + str(validate(qwen, val_loader)))