diff --git a/wit/query_block_output.py b/wit/query_block_output.py
index 5f31c98..6f964d7 100644
--- a/wit/query_block_output.py
+++ b/wit/query_block_output.py
@@ -66,7 +66,7 @@ def get_inference(dataset, seq):
 
 
 if __name__ == "__main__":
-    log_path = "log/bigger/version_2/"
+    log_path = "log/bigger/version_1/"
     file = get_latest_file_safe(log_path + "/checkpoints")
     checkpoint_path = log_path + "checkpoints/" + file
 
@@ -90,6 +90,7 @@ if __name__ == "__main__":
     # seq:991
     # seq:995
     meaning = 991
+    meaning = 10991
     node = map.get_nodetree(meaning)
     node.print()
 
diff --git a/wit/train.py b/wit/train.py
index 5fab291..5c5b70e 100644
--- a/wit/train.py
+++ b/wit/train.py
@@ -21,7 +21,7 @@ if __name__ == "__main__":
     conf.use_tril_attention_mask = None
     conf.precision = "16-mixed"  # "precision:bf16-mixed,16-mixed,32-true"
     conf.train_batch_size = 32
-    conf.val_batch_size = 2
+    conf.val_batch_size = 4
     conf.num_proc = 8
     conf.max_epochs = 30
     conf.strategy = "auto"
@@ -30,14 +30,15 @@ if __name__ == "__main__":
     conf.dataloader_works = 2
 
     conf.dataset.meaning.start = 10000
-    conf.dataset.meaning.end = 100000
+    conf.dataset.meaning.end = 50000
     conf.dataset.meaning.size = None
+    conf.dataset.meaning.reserve_vocab = 1
     conf.dataset.meaning.min_subitem = 2
     conf.dataset.meaning.max_subitem = 6
     conf.dataset.meaning.stride = 1
     conf.dataset.meaning.with_tree = False
-    conf.dataset.meaning.val_mask_level = [0, 1, 2]
-    conf.dataset.meaning.val_mask_idx = [0, 0, -1]
+    conf.dataset.meaning.val_mask_level = [0]
+    conf.dataset.meaning.val_mask_idx = [0]
 
     config.vocab_size = 32
     config.hidden_size = 128  # 128 1024 2048 32
@@ -58,6 +59,7 @@ if __name__ == "__main__":
 
     logger = TBLogger("./log/", name=conf.name)
     logger.log_hyperparams(configuration.class_to_dict(conf))
+    configuration.class_to_file(conf, logger.log_dir + "/conf.pkl")
 
     torch.set_float32_matmul_precision("medium")
     lit_trainer = pl.Trainer(