From edb8d8103ab6bc14e659fd9407c6c1159c6ca885 Mon Sep 17 00:00:00 2001 From: Colin <> Date: Tue, 26 Aug 2025 15:47:03 +0800 Subject: [PATCH] Refine query. --- wit/query_block_output.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/wit/query_block_output.py b/wit/query_block_output.py index 003f40e..46d9f3a 100644 --- a/wit/query_block_output.py +++ b/wit/query_block_output.py @@ -78,10 +78,11 @@ def get_inference(dataset, meaning): node.print() -def validate(model, dataloader): +def validate(model, dataloader, mask_level=[0, 1], mask_index=[0, -1]): logger = TBLogger("./log/", name=conf.name) logger.log_hyperparams(configuration.class_to_dict(conf)) + model.eval() lit_trainer = pl.Trainer( accelerator="cuda", @@ -91,10 +92,10 @@ def validate(model, dataloader): max_epochs=conf.max_epochs, ) - dataloader.dataset.meaning_dataset.set_mask([0, 1], [0, -1]) + dataloader.dataset.meaning_dataset.set_mask(mask_level, mask_index) - val = lit_trainer.validate(model, dataloaders=dataloader) - print(val) + val = lit_trainer.validate(model, dataloaders=dataloader, verbose=False) + return val[0]["accuracy"] def get_contain_meaning(dataset, meaning): @@ -151,6 +152,8 @@ def dump_dk(dataset, meaning): # print(sorted_logits.detach().cpu().numpy()) # print(sorted_indices.detach().cpu().numpy()) + qwen.llm.hook_attention = None + if __name__ == "__main__": @@ -187,4 +190,4 @@ if __name__ == "__main__": dump_dk(md, meaning) - validate(qwen, val_loader) + print("validate accuracy: " + str(validate(qwen, val_loader)))