Add config of stride with_tree end, Rename dataset to meaning.
This commit is contained in:
parent
1c7635556f
commit
71ab0bb57d
|
@ -39,11 +39,15 @@ class ModelConfig:
|
||||||
class MeaningDatasetConfig:
|
class MeaningDatasetConfig:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.start = 10000
|
self.start = 10000
|
||||||
self.size = 4000
|
self.end = 200000
|
||||||
|
self.size = None
|
||||||
self.min_subitem = 2
|
self.min_subitem = 2
|
||||||
self.max_subitem = 10
|
self.max_subitem = 10
|
||||||
self.val_mask_level = None
|
self.val_mask_level = None
|
||||||
self.val_mask_idx = None
|
self.val_mask_idx = None
|
||||||
|
self.stride = 1
|
||||||
|
self.with_tree = False
|
||||||
|
self.seed = 42
|
||||||
|
|
||||||
|
|
||||||
class DatasetConfig:
|
class DatasetConfig:
|
||||||
|
|
|
@ -4,7 +4,7 @@ from model.light_module import LightModule
|
||||||
from model.light_module import ModelRunner
|
from model.light_module import ModelRunner
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
import dataset.dataset as ds
|
import meaning.dataset as ds
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
from dataset.meaning_dataset import MeaningDataset, BatchGroupMeaningDataloader
|
from meaning.meaning_dataset import MeaningDataset, BatchGroupMeaningDataloader
|
||||||
from dataset.special_dataset import SpecialDataset
|
from meaning.special_dataset import SpecialDataset
|
||||||
from torch.utils.data import random_split, DataLoader
|
from torch.utils.data import random_split, DataLoader
|
||||||
import torch
|
import torch
|
||||||
import os
|
import os
|
||||||
|
@ -33,10 +33,15 @@ def InitDataset(config):
|
||||||
vocab = config.model_config.vocab_size
|
vocab = config.model_config.vocab_size
|
||||||
start = c.start
|
start = c.start
|
||||||
size = c.size
|
size = c.size
|
||||||
|
end = c.end
|
||||||
|
seed = c.seed
|
||||||
|
|
||||||
path = "./data/"
|
path = "./data/"
|
||||||
trainfile = path + f"MeaningDataset_train_v{size}_s{start}_s{size}_ms{c.min_subitem}_maxs{c.max_subitem}.pt"
|
conf_name = (
|
||||||
valfile = path + f"MeaningDataset_val_v{size}_s{start}_s{size}_ms{c.min_subitem}_maxs{c.max_subitem}.pt"
|
f"_s{start}_e{end}_s{size}_ms{c.min_subitem}_maxs{c.max_subitem}_stride{c.stride}_tree{c.with_tree}.pt"
|
||||||
|
)
|
||||||
|
trainfile = path + f"MeaningDataset_train" + conf_name
|
||||||
|
valfile = path + f"MeaningDataset_val" + conf_name
|
||||||
if not os.path.exists(path):
|
if not os.path.exists(path):
|
||||||
os.mkdir(path)
|
os.mkdir(path)
|
||||||
if os.path.exists(trainfile) and os.path.exists(valfile):
|
if os.path.exists(trainfile) and os.path.exists(valfile):
|
||||||
|
@ -48,7 +53,17 @@ def InitDataset(config):
|
||||||
val_dataset.set_mask(c.val_mask_level, c.val_mask_idx)
|
val_dataset.set_mask(c.val_mask_level, c.val_mask_idx)
|
||||||
print(f"INFO: Load dataset end")
|
print(f"INFO: Load dataset end")
|
||||||
else:
|
else:
|
||||||
raw_dataset = MeaningDataset(start, start + size, vocab, None, c.max_subitem, c.min_subitem)
|
raw_dataset = MeaningDataset(
|
||||||
|
start,
|
||||||
|
end,
|
||||||
|
vocab,
|
||||||
|
size,
|
||||||
|
c.max_subitem,
|
||||||
|
c.min_subitem,
|
||||||
|
stride=c.stride,
|
||||||
|
with_tree=c.with_tree,
|
||||||
|
seed=seed,
|
||||||
|
)
|
||||||
raw_dataset.set_mask(c.val_mask_level, c.val_mask_idx)
|
raw_dataset.set_mask(c.val_mask_level, c.val_mask_idx)
|
||||||
train_dataset, val_dataset = raw_dataset.split(0.9)
|
train_dataset, val_dataset = raw_dataset.split(0.9)
|
||||||
torch.save(train_dataset, trainfile)
|
torch.save(train_dataset, trainfile)
|
||||||
|
@ -82,9 +97,14 @@ def InitValDataset(config):
|
||||||
vocab = config.model_config.vocab_size
|
vocab = config.model_config.vocab_size
|
||||||
start = c.start
|
start = c.start
|
||||||
size = c.size
|
size = c.size
|
||||||
|
end = c.end
|
||||||
|
seed = c.seed
|
||||||
|
|
||||||
path = "./data/"
|
path = "./data/"
|
||||||
valfile = path + f"MeaningDataset_val_v{size}_s{start}_s{size}_ms{c.min_subitem}_maxs{c.max_subitem}.pt"
|
conf_name = (
|
||||||
|
f"_s{start}_e{end}_s{size}_ms{c.min_subitem}_maxs{c.max_subitem}_stride{c.stride}_tree{c.with_tree}.pt"
|
||||||
|
)
|
||||||
|
valfile = path + f"MeaningDataset_val" + conf_name
|
||||||
if not os.path.exists(path):
|
if not os.path.exists(path):
|
||||||
os.mkdir(path)
|
os.mkdir(path)
|
||||||
if os.path.exists(valfile):
|
if os.path.exists(valfile):
|
||||||
|
@ -93,7 +113,17 @@ def InitValDataset(config):
|
||||||
val_dataset.set_mask(c.val_mask_level, c.val_mask_idx)
|
val_dataset.set_mask(c.val_mask_level, c.val_mask_idx)
|
||||||
print(f"INFO: Load dataset end")
|
print(f"INFO: Load dataset end")
|
||||||
else:
|
else:
|
||||||
raw_dataset = MeaningDataset(start, start + size, vocab, None, c.max_subitem, c.min_subitem)
|
raw_dataset = MeaningDataset(
|
||||||
|
start,
|
||||||
|
end,
|
||||||
|
vocab,
|
||||||
|
size,
|
||||||
|
c.max_subitem,
|
||||||
|
c.min_subitem,
|
||||||
|
stride=c.stride,
|
||||||
|
with_tree=c.with_tree,
|
||||||
|
seed=seed,
|
||||||
|
)
|
||||||
raw_dataset.set_mask(c.val_mask_level, c.val_mask_idx)
|
raw_dataset.set_mask(c.val_mask_level, c.val_mask_idx)
|
||||||
train_dataset, val_dataset = raw_dataset.split(0.9)
|
train_dataset, val_dataset = raw_dataset.split(0.9)
|
||||||
torch.save(val_dataset, valfile)
|
torch.save(val_dataset, valfile)
|
|
@ -6,7 +6,7 @@ from typing import Dict, Tuple
|
||||||
from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split
|
from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from torch.utils.data import BatchSampler
|
from torch.utils.data import BatchSampler
|
||||||
from dataset.node_tree import NodeTree
|
from meaning.node_tree import NodeTree
|
||||||
|
|
||||||
|
|
||||||
class MeaningMap:
|
class MeaningMap:
|
|
@ -11,7 +11,7 @@ sys.path.append("..")
|
||||||
from tools import show
|
from tools import show
|
||||||
|
|
||||||
|
|
||||||
import dataset.dataset as ds
|
import meaning.dataset as ds
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
||||||
|
|
|
@ -7,7 +7,7 @@ from model.tokenization_qwen import QWenTokenizer
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
import configuration
|
import configuration
|
||||||
import dataset.dataset as ds
|
import meaning.dataset as ds
|
||||||
import dataset.node_tree as nt
|
import dataset.node_tree as nt
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
Loading…
Reference in New Issue