Add config of stride with_tree end, Rename dataset to meaning.
This commit is contained in:
parent
1c7635556f
commit
71ab0bb57d
|
@ -39,11 +39,15 @@ class ModelConfig:
|
|||
class MeaningDatasetConfig:
|
||||
def __init__(self):
|
||||
self.start = 10000
|
||||
self.size = 4000
|
||||
self.end = 200000
|
||||
self.size = None
|
||||
self.min_subitem = 2
|
||||
self.max_subitem = 10
|
||||
self.val_mask_level = None
|
||||
self.val_mask_idx = None
|
||||
self.stride = 1
|
||||
self.with_tree = False
|
||||
self.seed = 42
|
||||
|
||||
|
||||
class DatasetConfig:
|
||||
|
|
|
@ -4,7 +4,7 @@ from model.light_module import LightModule
|
|||
from model.light_module import ModelRunner
|
||||
import numpy as np
|
||||
|
||||
import dataset.dataset as ds
|
||||
import meaning.dataset as ds
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
from dataset.meaning_dataset import MeaningDataset, BatchGroupMeaningDataloader
|
||||
from dataset.special_dataset import SpecialDataset
|
||||
from meaning.meaning_dataset import MeaningDataset, BatchGroupMeaningDataloader
|
||||
from meaning.special_dataset import SpecialDataset
|
||||
from torch.utils.data import random_split, DataLoader
|
||||
import torch
|
||||
import os
|
||||
|
@ -33,10 +33,15 @@ def InitDataset(config):
|
|||
vocab = config.model_config.vocab_size
|
||||
start = c.start
|
||||
size = c.size
|
||||
end = c.end
|
||||
seed = c.seed
|
||||
|
||||
path = "./data/"
|
||||
trainfile = path + f"MeaningDataset_train_v{size}_s{start}_s{size}_ms{c.min_subitem}_maxs{c.max_subitem}.pt"
|
||||
valfile = path + f"MeaningDataset_val_v{size}_s{start}_s{size}_ms{c.min_subitem}_maxs{c.max_subitem}.pt"
|
||||
conf_name = (
|
||||
f"_s{start}_e{end}_s{size}_ms{c.min_subitem}_maxs{c.max_subitem}_stride{c.stride}_tree{c.with_tree}.pt"
|
||||
)
|
||||
trainfile = path + f"MeaningDataset_train" + conf_name
|
||||
valfile = path + f"MeaningDataset_val" + conf_name
|
||||
if not os.path.exists(path):
|
||||
os.mkdir(path)
|
||||
if os.path.exists(trainfile) and os.path.exists(valfile):
|
||||
|
@ -48,7 +53,17 @@ def InitDataset(config):
|
|||
val_dataset.set_mask(c.val_mask_level, c.val_mask_idx)
|
||||
print(f"INFO: Load dataset end")
|
||||
else:
|
||||
raw_dataset = MeaningDataset(start, start + size, vocab, None, c.max_subitem, c.min_subitem)
|
||||
raw_dataset = MeaningDataset(
|
||||
start,
|
||||
end,
|
||||
vocab,
|
||||
size,
|
||||
c.max_subitem,
|
||||
c.min_subitem,
|
||||
stride=c.stride,
|
||||
with_tree=c.with_tree,
|
||||
seed=seed,
|
||||
)
|
||||
raw_dataset.set_mask(c.val_mask_level, c.val_mask_idx)
|
||||
train_dataset, val_dataset = raw_dataset.split(0.9)
|
||||
torch.save(train_dataset, trainfile)
|
||||
|
@ -82,9 +97,14 @@ def InitValDataset(config):
|
|||
vocab = config.model_config.vocab_size
|
||||
start = c.start
|
||||
size = c.size
|
||||
end = c.end
|
||||
seed = c.seed
|
||||
|
||||
path = "./data/"
|
||||
valfile = path + f"MeaningDataset_val_v{size}_s{start}_s{size}_ms{c.min_subitem}_maxs{c.max_subitem}.pt"
|
||||
conf_name = (
|
||||
f"_s{start}_e{end}_s{size}_ms{c.min_subitem}_maxs{c.max_subitem}_stride{c.stride}_tree{c.with_tree}.pt"
|
||||
)
|
||||
valfile = path + f"MeaningDataset_val" + conf_name
|
||||
if not os.path.exists(path):
|
||||
os.mkdir(path)
|
||||
if os.path.exists(valfile):
|
||||
|
@ -93,7 +113,17 @@ def InitValDataset(config):
|
|||
val_dataset.set_mask(c.val_mask_level, c.val_mask_idx)
|
||||
print(f"INFO: Load dataset end")
|
||||
else:
|
||||
raw_dataset = MeaningDataset(start, start + size, vocab, None, c.max_subitem, c.min_subitem)
|
||||
raw_dataset = MeaningDataset(
|
||||
start,
|
||||
end,
|
||||
vocab,
|
||||
size,
|
||||
c.max_subitem,
|
||||
c.min_subitem,
|
||||
stride=c.stride,
|
||||
with_tree=c.with_trees,
|
||||
seed=seed,
|
||||
)
|
||||
raw_dataset.set_mask(c.val_mask_level, c.val_mask_idx)
|
||||
train_dataset, val_dataset = raw_dataset.split(0.9)
|
||||
torch.save(val_dataset, valfile)
|
|
@ -6,7 +6,7 @@ from typing import Dict, Tuple
|
|||
from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split
|
||||
import numpy as np
|
||||
from torch.utils.data import BatchSampler
|
||||
from dataset.node_tree import NodeTree
|
||||
from meaning.node_tree import NodeTree
|
||||
|
||||
|
||||
class MeaningMap:
|
|
@ -11,7 +11,7 @@ sys.path.append("..")
|
|||
from tools import show
|
||||
|
||||
|
||||
import dataset.dataset as ds
|
||||
import meaning.dataset as ds
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
|
|
|
@ -7,7 +7,7 @@ from model.tokenization_qwen import QWenTokenizer
|
|||
import numpy as np
|
||||
|
||||
import configuration
|
||||
import dataset.dataset as ds
|
||||
import meaning.dataset as ds
|
||||
import dataset.node_tree as nt
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
Loading…
Reference in New Issue