Refine query.
This commit is contained in:
parent
6d0a410731
commit
edb8d8103a
|
@ -78,10 +78,11 @@ def get_inference(dataset, meaning):
|
|||
node.print()
|
||||
|
||||
|
||||
def validate(model, dataloader):
|
||||
def validate(model, dataloader, mask_level=[0, 1], mask_index=[0, -1]):
|
||||
|
||||
logger = TBLogger("./log/", name=conf.name)
|
||||
logger.log_hyperparams(configuration.class_to_dict(conf))
|
||||
model.eval()
|
||||
|
||||
lit_trainer = pl.Trainer(
|
||||
accelerator="cuda",
|
||||
|
@ -91,10 +92,10 @@ def validate(model, dataloader):
|
|||
max_epochs=conf.max_epochs,
|
||||
)
|
||||
|
||||
dataloader.dataset.meaning_dataset.set_mask([0, 1], [0, -1])
|
||||
dataloader.dataset.meaning_dataset.set_mask(mask_level, mask_index)
|
||||
|
||||
val = lit_trainer.validate(model, dataloaders=dataloader)
|
||||
print(val)
|
||||
val = lit_trainer.validate(model, dataloaders=dataloader, verbose=False)
|
||||
return val[0]["accuracy"]
|
||||
|
||||
|
||||
def get_contain_meaning(dataset, meaning):
|
||||
|
@ -151,6 +152,8 @@ def dump_dk(dataset, meaning):
|
|||
# print(sorted_logits.detach().cpu().numpy())
|
||||
# print(sorted_indices.detach().cpu().numpy())
|
||||
|
||||
qwen.llm.hook_attention = None
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
|
@ -187,4 +190,4 @@ if __name__ == "__main__":
|
|||
|
||||
dump_dk(md, meaning)
|
||||
|
||||
validate(qwen, val_loader)
|
||||
print("validate accuracy: " + str(validate(qwen, val_loader)))
|
||||
|
|
Loading…
Reference in New Issue