Refine import code.

This commit is contained in:
Colin 2025-08-10 15:10:20 +08:00
parent b62444a9dc
commit b56bbb2952
4 changed files with 11 additions and 7 deletions

View File

@ -0,0 +1,2 @@
from .dataset import InitDataset
from .dataset import InitValDataset

View File

@ -1,5 +1,5 @@
from meaning.meaning_dataset import MeaningDataset, BatchGroupMeaningDataloader
from meaning.special_dataset import SpecialDataset
from .meaning_dataset import MeaningDataset, BatchGroupMeaningDataloader
from .special_dataset import SpecialDataset
from torch.utils.data import random_split, DataLoader
import torch
import os

View File

@ -7,7 +7,11 @@ from typing import Dict, Tuple
from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split
import numpy as np
from torch.utils.data import BatchSampler
if __name__ == "__main__":
from node_tree import NodeTree
else:
from .node_tree import NodeTree
class MeaningMap:

View File

@ -2,13 +2,11 @@ import pytorch_lightning as pl
import torch
from model.light_module import LightModule
from model.modeling_wit import ModelRunner
from model.tokenization_qwen import QWenTokenizer
import numpy as np
import configuration
import meaning.dataset as ds
import dataset.node_tree as nt
import meaning as m
if __name__ == "__main__":
@ -20,7 +18,7 @@ if __name__ == "__main__":
torch.manual_seed(conf.seed)
np.random.seed(conf.seed)
train_dataloader, val_dataloader = ds.InitDataset(conf)
train_dataloader, val_dataloader = m.InitDataset(conf)
loader = train_dataloader.dataset