Refine import code.
This commit is contained in:
parent
b62444a9dc
commit
b56bbb2952
|
@ -0,0 +1,2 @@
|
||||||
|
from .dataset import InitDataset
|
||||||
|
from .dataset import InitValDataset
|
|
@ -1,5 +1,5 @@
|
||||||
from meaning.meaning_dataset import MeaningDataset, BatchGroupMeaningDataloader
|
from .meaning_dataset import MeaningDataset, BatchGroupMeaningDataloader
|
||||||
from meaning.special_dataset import SpecialDataset
|
from .special_dataset import SpecialDataset
|
||||||
from torch.utils.data import random_split, DataLoader
|
from torch.utils.data import random_split, DataLoader
|
||||||
import torch
|
import torch
|
||||||
import os
|
import os
|
||||||
|
|
|
@ -7,7 +7,11 @@ from typing import Dict, Tuple
|
||||||
from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split
|
from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from torch.utils.data import BatchSampler
|
from torch.utils.data import BatchSampler
|
||||||
from node_tree import NodeTree
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
from node_tree import NodeTree
|
||||||
|
else:
|
||||||
|
from .node_tree import NodeTree
|
||||||
|
|
||||||
|
|
||||||
class MeaningMap:
|
class MeaningMap:
|
||||||
|
|
|
@ -2,13 +2,11 @@ import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from model.light_module import LightModule
|
from model.light_module import LightModule
|
||||||
from model.modeling_wit import ModelRunner
|
|
||||||
from model.tokenization_qwen import QWenTokenizer
|
from model.tokenization_qwen import QWenTokenizer
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
import configuration
|
import configuration
|
||||||
import meaning.dataset as ds
|
import meaning as m
|
||||||
import dataset.node_tree as nt
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
||||||
|
@ -20,7 +18,7 @@ if __name__ == "__main__":
|
||||||
torch.manual_seed(conf.seed)
|
torch.manual_seed(conf.seed)
|
||||||
np.random.seed(conf.seed)
|
np.random.seed(conf.seed)
|
||||||
|
|
||||||
train_dataloader, val_dataloader = ds.InitDataset(conf)
|
train_dataloader, val_dataloader = m.InitDataset(conf)
|
||||||
|
|
||||||
loader = train_dataloader.dataset
|
loader = train_dataloader.dataset
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue