From b56bbb295260d6cedbb6f4114784194a1437d30f Mon Sep 17 00:00:00 2001 From: Colin <> Date: Sun, 10 Aug 2025 15:10:20 +0800 Subject: [PATCH] Refine import code. --- wit/meaning/__init__.py | 2 ++ wit/meaning/dataset.py | 4 ++-- wit/meaning/meaning_dataset.py | 6 +++++- wit/query_meaning_freq.py | 6 ++---- 4 files changed, 11 insertions(+), 7 deletions(-) diff --git a/wit/meaning/__init__.py b/wit/meaning/__init__.py index e69de29..91f4e26 100644 --- a/wit/meaning/__init__.py +++ b/wit/meaning/__init__.py @@ -0,0 +1,2 @@ +from .dataset import InitDataset +from .dataset import InitValDataset diff --git a/wit/meaning/dataset.py b/wit/meaning/dataset.py index 6dde383..b196a81 100644 --- a/wit/meaning/dataset.py +++ b/wit/meaning/dataset.py @@ -1,5 +1,5 @@ -from meaning.meaning_dataset import MeaningDataset, BatchGroupMeaningDataloader -from meaning.special_dataset import SpecialDataset +from .meaning_dataset import MeaningDataset, BatchGroupMeaningDataloader +from .special_dataset import SpecialDataset from torch.utils.data import random_split, DataLoader import torch import os diff --git a/wit/meaning/meaning_dataset.py b/wit/meaning/meaning_dataset.py index f80a09c..c5e043d 100644 --- a/wit/meaning/meaning_dataset.py +++ b/wit/meaning/meaning_dataset.py @@ -7,7 +7,11 @@ from typing import Dict, Tuple from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split import numpy as np from torch.utils.data import BatchSampler -from node_tree import NodeTree + +if __name__ == "__main__": + from node_tree import NodeTree +else: + from .node_tree import NodeTree class MeaningMap: diff --git a/wit/query_meaning_freq.py b/wit/query_meaning_freq.py index 65021db..855f75a 100644 --- a/wit/query_meaning_freq.py +++ b/wit/query_meaning_freq.py @@ -2,13 +2,11 @@ import pytorch_lightning as pl import torch from model.light_module import LightModule -from model.modeling_wit import ModelRunner from model.tokenization_qwen import QWenTokenizer import numpy as np import configuration -import meaning.dataset as ds -import dataset.node_tree as nt +import meaning as m if __name__ == "__main__": @@ -20,7 +18,7 @@ if __name__ == "__main__": torch.manual_seed(conf.seed) np.random.seed(conf.seed) - train_dataloader, val_dataloader = ds.InitDataset(conf) + train_dataloader, val_dataloader = m.InitDataset(conf) loader = train_dataloader.dataset