Reset base config of model.
This commit is contained in:
parent
d48a5563aa
commit
a326df1bba
|
@ -66,7 +66,7 @@ def get_inference(dataset, seq):
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
||||||
log_path = "log/bigger/version_2/"
|
log_path = "log/bigger/version_1/"
|
||||||
|
|
||||||
file = get_latest_file_safe(log_path + "/checkpoints")
|
file = get_latest_file_safe(log_path + "/checkpoints")
|
||||||
checkpoint_path = log_path + "checkpoints/" + file
|
checkpoint_path = log_path + "checkpoints/" + file
|
||||||
|
@ -90,6 +90,7 @@ if __name__ == "__main__":
|
||||||
# seq:991
|
# seq:991
|
||||||
# seq:995
|
# seq:995
|
||||||
meaning = 991
|
meaning = 991
|
||||||
|
meaning = 10991
|
||||||
|
|
||||||
node = map.get_nodetree(meaning)
|
node = map.get_nodetree(meaning)
|
||||||
node.print()
|
node.print()
|
||||||
|
|
10
wit/train.py
10
wit/train.py
|
@ -21,7 +21,7 @@ if __name__ == "__main__":
|
||||||
conf.use_tril_attention_mask = None
|
conf.use_tril_attention_mask = None
|
||||||
conf.precision = "16-mixed" # "precision:bf16-mixed,16-mixed,32-true"
|
conf.precision = "16-mixed" # "precision:bf16-mixed,16-mixed,32-true"
|
||||||
conf.train_batch_size = 32
|
conf.train_batch_size = 32
|
||||||
conf.val_batch_size = 2
|
conf.val_batch_size = 4
|
||||||
conf.num_proc = 8
|
conf.num_proc = 8
|
||||||
conf.max_epochs = 30
|
conf.max_epochs = 30
|
||||||
conf.strategy = "auto"
|
conf.strategy = "auto"
|
||||||
|
@ -30,14 +30,15 @@ if __name__ == "__main__":
|
||||||
conf.dataloader_works = 2
|
conf.dataloader_works = 2
|
||||||
|
|
||||||
conf.dataset.meaning.start = 10000
|
conf.dataset.meaning.start = 10000
|
||||||
conf.dataset.meaning.end = 100000
|
conf.dataset.meaning.end = 50000
|
||||||
conf.dataset.meaning.size = None
|
conf.dataset.meaning.size = None
|
||||||
|
conf.dataset.meaning.reserve_vocab = 1
|
||||||
conf.dataset.meaning.min_subitem = 2
|
conf.dataset.meaning.min_subitem = 2
|
||||||
conf.dataset.meaning.max_subitem = 6
|
conf.dataset.meaning.max_subitem = 6
|
||||||
conf.dataset.meaning.stride = 1
|
conf.dataset.meaning.stride = 1
|
||||||
conf.dataset.meaning.with_tree = False
|
conf.dataset.meaning.with_tree = False
|
||||||
conf.dataset.meaning.val_mask_level = [0, 1, 2]
|
conf.dataset.meaning.val_mask_level = [0]
|
||||||
conf.dataset.meaning.val_mask_idx = [0, 0, -1]
|
conf.dataset.meaning.val_mask_idx = [0]
|
||||||
|
|
||||||
config.vocab_size = 32
|
config.vocab_size = 32
|
||||||
config.hidden_size = 128 # 128 1024 2048 32
|
config.hidden_size = 128 # 128 1024 2048 32
|
||||||
|
@ -58,6 +59,7 @@ if __name__ == "__main__":
|
||||||
|
|
||||||
logger = TBLogger("./log/", name=conf.name)
|
logger = TBLogger("./log/", name=conf.name)
|
||||||
logger.log_hyperparams(configuration.class_to_dict(conf))
|
logger.log_hyperparams(configuration.class_to_dict(conf))
|
||||||
|
configuration.class_to_file(conf, logger.log_dir + "/conf.pkl")
|
||||||
|
|
||||||
torch.set_float32_matmul_precision("medium")
|
torch.set_float32_matmul_precision("medium")
|
||||||
lit_trainer = pl.Trainer(
|
lit_trainer = pl.Trainer(
|
||||||
|
|
Loading…
Reference in New Issue