From 71ab0bb57d36cae041234d695dc01278c20911a7 Mon Sep 17 00:00:00 2001
From: Colin <>
Date: Sun, 10 Aug 2025 13:17:04 +0800
Subject: [PATCH] Add stride, with_tree and end config options; rename
 dataset to meaning.

---
 wit/configuration.py                        |  6 ++-
 wit/inference.py                            |  2 +-
 wit/{dataset => meaning}/dataset.py         | 44 +++++++++++++++++----
 wit/{dataset => meaning}/meaning_dataset.py |  2 +-
 wit/{dataset => meaning}/node_tree.py       |  0
 wit/{dataset => meaning}/special_dataset.py |  0
 wit/query_block_output.py                   |  2 +-
 wit/query_meaning_freq.py                   |  2 +-
 8 files changed, 46 insertions(+), 12 deletions(-)
 rename wit/{dataset => meaning}/dataset.py (73%)
 rename wit/{dataset => meaning}/meaning_dataset.py (99%)
 rename wit/{dataset => meaning}/node_tree.py (100%)
 rename wit/{dataset => meaning}/special_dataset.py (100%)

diff --git a/wit/configuration.py b/wit/configuration.py
index d512c36..d8c419f 100644
--- a/wit/configuration.py
+++ b/wit/configuration.py
@@ -39,11 +39,15 @@ class ModelConfig:
 class MeaningDatasetConfig:
     def __init__(self):
         self.start = 10000
-        self.size = 4000
+        self.end = 200000
+        self.size = None
         self.min_subitem = 2
         self.max_subitem = 10
         self.val_mask_level = None
         self.val_mask_idx = None
+        self.stride = 1
+        self.with_tree = False
+        self.seed = 42
 
 
 class DatasetConfig:
diff --git a/wit/inference.py b/wit/inference.py
index 9fa265e..80ef079 100644
--- a/wit/inference.py
+++ b/wit/inference.py
@@ -4,7 +4,7 @@
 from model.light_module import LightModule
 from model.light_module import ModelRunner
 import numpy as np
-import dataset.dataset as ds
+import meaning.dataset as ds
 
 
 if __name__ == "__main__":
diff --git a/wit/dataset/dataset.py b/wit/meaning/dataset.py
similarity index 73%
rename from wit/dataset/dataset.py
rename to wit/meaning/dataset.py
index edfcf9e..6dde383 100644
--- a/wit/dataset/dataset.py
+++ b/wit/meaning/dataset.py
@@ -1,5 +1,5 @@
-from dataset.meaning_dataset import MeaningDataset, BatchGroupMeaningDataloader
-from dataset.special_dataset import SpecialDataset
+from meaning.meaning_dataset import MeaningDataset, BatchGroupMeaningDataloader
+from meaning.special_dataset import SpecialDataset
 from torch.utils.data import random_split, DataLoader
 import torch
 import os
@@ -33,10 +33,15 @@ def InitDataset(config):
     vocab = config.model_config.vocab_size
     start = c.start
     size = c.size
+    end = c.end
+    seed = c.seed
 
     path = "./data/"
-    trainfile = path + f"MeaningDataset_train_v{size}_s{start}_s{size}_ms{c.min_subitem}_maxs{c.max_subitem}.pt"
-    valfile = path + f"MeaningDataset_val_v{size}_s{start}_s{size}_ms{c.min_subitem}_maxs{c.max_subitem}.pt"
+    conf_name = (
+        f"_s{start}_e{end}_s{size}_ms{c.min_subitem}_maxs{c.max_subitem}_stride{c.stride}_tree{c.with_tree}.pt"
+    )
+    trainfile = path + f"MeaningDataset_train" + conf_name
+    valfile = path + f"MeaningDataset_val" + conf_name
     if not os.path.exists(path):
         os.mkdir(path)
     if os.path.exists(trainfile) and os.path.exists(valfile):
@@ -48,7 +53,17 @@
         val_dataset.set_mask(c.val_mask_level, c.val_mask_idx)
         print(f"INFO: Load dataset end")
     else:
-        raw_dataset = MeaningDataset(start, start + size, vocab, None, c.max_subitem, c.min_subitem)
+        raw_dataset = MeaningDataset(
+            start,
+            end,
+            vocab,
+            size,
+            c.max_subitem,
+            c.min_subitem,
+            stride=c.stride,
+            with_tree=c.with_tree,
+            seed=seed,
+        )
         raw_dataset.set_mask(c.val_mask_level, c.val_mask_idx)
         train_dataset, val_dataset = raw_dataset.split(0.9)
         torch.save(train_dataset, trainfile)
@@ -82,9 +97,14 @@ def InitValDataset(config):
     vocab = config.model_config.vocab_size
     start = c.start
     size = c.size
+    end = c.end
+    seed = c.seed
 
     path = "./data/"
-    valfile = path + f"MeaningDataset_val_v{size}_s{start}_s{size}_ms{c.min_subitem}_maxs{c.max_subitem}.pt"
+    conf_name = (
+        f"_s{start}_e{end}_s{size}_ms{c.min_subitem}_maxs{c.max_subitem}_stride{c.stride}_tree{c.with_tree}.pt"
+    )
+    valfile = path + f"MeaningDataset_val" + conf_name
     if not os.path.exists(path):
         os.mkdir(path)
     if os.path.exists(valfile):
@@ -93,7 +113,17 @@
         val_dataset.set_mask(c.val_mask_level, c.val_mask_idx)
         print(f"INFO: Load dataset end")
     else:
-        raw_dataset = MeaningDataset(start, start + size, vocab, None, c.max_subitem, c.min_subitem)
+        raw_dataset = MeaningDataset(
+            start,
+            end,
+            vocab,
+            size,
+            c.max_subitem,
+            c.min_subitem,
+            stride=c.stride,
+            with_tree=c.with_tree,
+            seed=seed,
+        )
         raw_dataset.set_mask(c.val_mask_level, c.val_mask_idx)
         train_dataset, val_dataset = raw_dataset.split(0.9)
         torch.save(val_dataset, valfile)
diff --git a/wit/dataset/meaning_dataset.py b/wit/meaning/meaning_dataset.py
similarity index 99%
rename from wit/dataset/meaning_dataset.py
rename to wit/meaning/meaning_dataset.py
index df0cfcc..fbb3c0d 100644
--- a/wit/dataset/meaning_dataset.py
+++ b/wit/meaning/meaning_dataset.py
@@ -6,7 +6,7 @@ from typing import Dict, Tuple
 from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split
 import numpy as np
 from torch.utils.data import BatchSampler
-from dataset.node_tree import NodeTree
+from meaning.node_tree import NodeTree
 
 
 class MeaningMap:
diff --git a/wit/dataset/node_tree.py b/wit/meaning/node_tree.py
similarity index 100%
rename from wit/dataset/node_tree.py
rename to wit/meaning/node_tree.py
diff --git a/wit/dataset/special_dataset.py b/wit/meaning/special_dataset.py
similarity index 100%
rename from wit/dataset/special_dataset.py
rename to wit/meaning/special_dataset.py
diff --git a/wit/query_block_output.py b/wit/query_block_output.py
index 27157ca..379427c 100644
--- a/wit/query_block_output.py
+++ b/wit/query_block_output.py
@@ -11,7 +11,7 @@
 sys.path.append("..")
 from tools import show
 
-import dataset.dataset as ds
+import meaning.dataset as ds
 
 
 if __name__ == "__main__":
diff --git a/wit/query_meaning_freq.py b/wit/query_meaning_freq.py
index 87087c1..65021db 100644
--- a/wit/query_meaning_freq.py
+++ b/wit/query_meaning_freq.py
@@ -7,7 +7,7 @@
 from model.tokenization_qwen import QWenTokenizer
 import numpy as np
 import configuration
-import dataset.dataset as ds
+import meaning.dataset as ds
 import dataset.node_tree as nt
 
 if __name__ == "__main__":
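
Note: below is a minimal usage sketch of the new MeaningDatasetConfig fields introduced by this patch, not part of the patch itself. Only MeaningDatasetConfig and the conf_name suffix format come from the code above; how this config is attached to the full training configuration, and the exact semantics of stride and with_tree, are assumptions for illustration.

    # Hypothetical sketch -- not part of the patch. MeaningDatasetConfig is
    # taken from wit/configuration.py; the suffix below mirrors the conf_name
    # built in wit/meaning/dataset.py, so it shows which .pt cache file a
    # given configuration would map to.
    from configuration import MeaningDatasetConfig

    c = MeaningDatasetConfig()
    c.start = 10000       # first meaning id
    c.end = 200000        # replaces the old start + size upper bound
    c.size = None         # None: use the whole [start, end) range
    c.stride = 4          # new field; step between sampled meanings (assumed)
    c.with_tree = True    # new field; keep the per-sample meaning tree (assumed)
    c.seed = 42           # new field; seed for reproducible generation

    conf_name = (
        f"_s{c.start}_e{c.end}_s{c.size}_ms{c.min_subitem}_maxs{c.max_subitem}"
        f"_stride{c.stride}_tree{c.with_tree}.pt"
    )
    # Every generation parameter except seed is encoded in the cache name,
    # so changing stride or with_tree no longer reuses a stale dataset file.
    print("./data/MeaningDataset_train" + conf_name)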