diff --git a/wit/dataset/meaning_dataset.py b/wit/dataset/meaning_dataset.py index f637b62..df0cfcc 100644 --- a/wit/dataset/meaning_dataset.py +++ b/wit/dataset/meaning_dataset.py @@ -6,7 +6,7 @@ from typing import Dict, Tuple from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split import numpy as np from torch.utils.data import BatchSampler -from node_tree import NodeTree +from dataset.node_tree import NodeTree class MeaningMap: diff --git a/wit/train.py b/wit/train.py index 9021e32..48a6c99 100644 --- a/wit/train.py +++ b/wit/train.py @@ -39,7 +39,7 @@ if __name__ == "__main__": config.vocab_size = 32 config.hidden_size = 256 # 128 1024 2048 32 config.intermediate_size = 512 - config.num_hidden_layers = 4 # 6 12 24 3 + config.num_hidden_layers = 3 # 6 12 24 3 config.num_attention_heads = 4 # 8 8 16 torch.manual_seed(conf.seed)