Reconfig train model and code.
This commit is contained in:
parent
9c75b8920d
commit
1c7635556f
|
@ -6,7 +6,7 @@ from typing import Dict, Tuple
|
||||||
from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split
|
from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from torch.utils.data import BatchSampler
|
from torch.utils.data import BatchSampler
|
||||||
from node_tree import NodeTree
|
from dataset.node_tree import NodeTree
|
||||||
|
|
||||||
|
|
||||||
class MeaningMap:
|
class MeaningMap:
|
||||||
|
|
|
@ -39,7 +39,7 @@ if __name__ == "__main__":
|
||||||
config.vocab_size = 32
|
config.vocab_size = 32
|
||||||
config.hidden_size = 256 # 128 1024 2048 32
|
config.hidden_size = 256 # 128 1024 2048 32
|
||||||
config.intermediate_size = 512
|
config.intermediate_size = 512
|
||||||
config.num_hidden_layers = 4 # 6 12 24 3
|
config.num_hidden_layers = 3 # 6 12 24 3
|
||||||
config.num_attention_heads = 4 # 8 8 16
|
config.num_attention_heads = 4 # 8 8 16
|
||||||
|
|
||||||
torch.manual_seed(conf.seed)
|
torch.manual_seed(conf.seed)
|
||||||
|
|
Loading…
Reference in New Issue