Add config options for stride, with_tree, and end; rename the dataset package to meaning.

This commit is contained in:
Colin 2025-08-10 13:17:04 +08:00
parent 1c7635556f
commit 71ab0bb57d
8 changed files with 46 additions and 12 deletions

View File

@ -39,11 +39,15 @@ class ModelConfig:
class MeaningDatasetConfig: class MeaningDatasetConfig:
def __init__(self): def __init__(self):
self.start = 10000 self.start = 10000
self.size = 4000 self.end = 200000
self.size = None
self.min_subitem = 2 self.min_subitem = 2
self.max_subitem = 10 self.max_subitem = 10
self.val_mask_level = None self.val_mask_level = None
self.val_mask_idx = None self.val_mask_idx = None
self.stride = 1
self.with_tree = False
self.seed = 42
class DatasetConfig: class DatasetConfig:

View File

@ -4,7 +4,7 @@ from model.light_module import LightModule
from model.light_module import ModelRunner from model.light_module import ModelRunner
import numpy as np import numpy as np
import dataset.dataset as ds import meaning.dataset as ds
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -1,5 +1,5 @@
from dataset.meaning_dataset import MeaningDataset, BatchGroupMeaningDataloader from meaning.meaning_dataset import MeaningDataset, BatchGroupMeaningDataloader
from dataset.special_dataset import SpecialDataset from meaning.special_dataset import SpecialDataset
from torch.utils.data import random_split, DataLoader from torch.utils.data import random_split, DataLoader
import torch import torch
import os import os
@ -33,10 +33,15 @@ def InitDataset(config):
vocab = config.model_config.vocab_size vocab = config.model_config.vocab_size
start = c.start start = c.start
size = c.size size = c.size
end = c.end
seed = c.seed
path = "./data/" path = "./data/"
trainfile = path + f"MeaningDataset_train_v{size}_s{start}_s{size}_ms{c.min_subitem}_maxs{c.max_subitem}.pt" conf_name = (
valfile = path + f"MeaningDataset_val_v{size}_s{start}_s{size}_ms{c.min_subitem}_maxs{c.max_subitem}.pt" f"_s{start}_e{end}_s{size}_ms{c.min_subitem}_maxs{c.max_subitem}_stride{c.stride}_tree{c.with_tree}.pt"
)
trainfile = path + f"MeaningDataset_train" + conf_name
valfile = path + f"MeaningDataset_val" + conf_name
if not os.path.exists(path): if not os.path.exists(path):
os.mkdir(path) os.mkdir(path)
if os.path.exists(trainfile) and os.path.exists(valfile): if os.path.exists(trainfile) and os.path.exists(valfile):
@ -48,7 +53,17 @@ def InitDataset(config):
val_dataset.set_mask(c.val_mask_level, c.val_mask_idx) val_dataset.set_mask(c.val_mask_level, c.val_mask_idx)
print(f"INFO: Load dataset end") print(f"INFO: Load dataset end")
else: else:
raw_dataset = MeaningDataset(start, start + size, vocab, None, c.max_subitem, c.min_subitem) raw_dataset = MeaningDataset(
start,
end,
vocab,
size,
c.max_subitem,
c.min_subitem,
stride=c.stride,
with_tree=c.with_tree,
seed=seed,
)
raw_dataset.set_mask(c.val_mask_level, c.val_mask_idx) raw_dataset.set_mask(c.val_mask_level, c.val_mask_idx)
train_dataset, val_dataset = raw_dataset.split(0.9) train_dataset, val_dataset = raw_dataset.split(0.9)
torch.save(train_dataset, trainfile) torch.save(train_dataset, trainfile)
@ -82,9 +97,14 @@ def InitValDataset(config):
vocab = config.model_config.vocab_size vocab = config.model_config.vocab_size
start = c.start start = c.start
size = c.size size = c.size
end = c.end
seed = c.seed
path = "./data/" path = "./data/"
valfile = path + f"MeaningDataset_val_v{size}_s{start}_s{size}_ms{c.min_subitem}_maxs{c.max_subitem}.pt" conf_name = (
f"_s{start}_e{end}_s{size}_ms{c.min_subitem}_maxs{c.max_subitem}_stride{c.stride}_tree{c.with_tree}.pt"
)
valfile = path + f"MeaningDataset_val" + conf_name
if not os.path.exists(path): if not os.path.exists(path):
os.mkdir(path) os.mkdir(path)
if os.path.exists(valfile): if os.path.exists(valfile):
@ -93,7 +113,17 @@ def InitValDataset(config):
val_dataset.set_mask(c.val_mask_level, c.val_mask_idx) val_dataset.set_mask(c.val_mask_level, c.val_mask_idx)
print(f"INFO: Load dataset end") print(f"INFO: Load dataset end")
else: else:
raw_dataset = MeaningDataset(start, start + size, vocab, None, c.max_subitem, c.min_subitem) raw_dataset = MeaningDataset(
start,
end,
vocab,
size,
c.max_subitem,
c.min_subitem,
stride=c.stride,
with_tree=c.with_tree,
seed=seed,
)
raw_dataset.set_mask(c.val_mask_level, c.val_mask_idx) raw_dataset.set_mask(c.val_mask_level, c.val_mask_idx)
train_dataset, val_dataset = raw_dataset.split(0.9) train_dataset, val_dataset = raw_dataset.split(0.9)
torch.save(val_dataset, valfile) torch.save(val_dataset, valfile)

View File

@ -6,7 +6,7 @@ from typing import Dict, Tuple
from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split
import numpy as np import numpy as np
from torch.utils.data import BatchSampler from torch.utils.data import BatchSampler
from dataset.node_tree import NodeTree from meaning.node_tree import NodeTree
class MeaningMap: class MeaningMap:

View File

@ -11,7 +11,7 @@ sys.path.append("..")
from tools import show from tools import show
import dataset.dataset as ds import meaning.dataset as ds
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -7,7 +7,7 @@ from model.tokenization_qwen import QWenTokenizer
import numpy as np import numpy as np
import configuration import configuration
import dataset.dataset as ds import meaning.dataset as ds
import dataset.node_tree as nt import meaning.node_tree as nt
if __name__ == "__main__": if __name__ == "__main__":