Refine import code.
This commit is contained in:
parent
b62444a9dc
commit
b56bbb2952
|
@ -0,0 +1,2 @@
|
|||
from .dataset import InitDataset
|
||||
from .dataset import InitValDataset
|
|
@ -1,5 +1,5 @@
|
|||
from meaning.meaning_dataset import MeaningDataset, BatchGroupMeaningDataloader
|
||||
from meaning.special_dataset import SpecialDataset
|
||||
from .meaning_dataset import MeaningDataset, BatchGroupMeaningDataloader
|
||||
from .special_dataset import SpecialDataset
|
||||
from torch.utils.data import random_split, DataLoader
|
||||
import torch
|
||||
import os
|
||||
|
|
|
@ -7,7 +7,11 @@ from typing import Dict, Tuple
|
|||
from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split
|
||||
import numpy as np
|
||||
from torch.utils.data import BatchSampler
|
||||
from node_tree import NodeTree
|
||||
|
||||
if __name__ == "__main__":
|
||||
from node_tree import NodeTree
|
||||
else:
|
||||
from .node_tree import NodeTree
|
||||
|
||||
|
||||
class MeaningMap:
|
||||
|
|
|
@ -2,13 +2,11 @@ import pytorch_lightning as pl
|
|||
import torch
|
||||
|
||||
from model.light_module import LightModule
|
||||
from model.modeling_wit import ModelRunner
|
||||
from model.tokenization_qwen import QWenTokenizer
|
||||
import numpy as np
|
||||
|
||||
import configuration
|
||||
import meaning.dataset as ds
|
||||
import dataset.node_tree as nt
|
||||
import meaning as m
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
|
@ -20,7 +18,7 @@ if __name__ == "__main__":
|
|||
torch.manual_seed(conf.seed)
|
||||
np.random.seed(conf.seed)
|
||||
|
||||
train_dataloader, val_dataloader = ds.InitDataset(conf)
|
||||
train_dataloader, val_dataloader = m.InitDataset(conf)
|
||||
|
||||
loader = train_dataloader.dataset
|
||||
|
||||
|
|
Loading…
Reference in New Issue