# 62 lines, 2.6 KiB, Python
import os

import torch
from torch.utils.data import DataLoader, random_split

from dataset.meaning_dataset import BatchGroupMeaningDataloader, MeaningDataset
from dataset.special_dataset import SpecialDataset

def InitDataset(config):
    """Build train/val dataloaders for the dataset selected by ``config.dataset.name``.

    Supports two datasets:
      * ``"special"``  — a ``SpecialDataset`` split 95/5 and wrapped in plain
        ``DataLoader``s.
      * ``"meaning"``  — a ``MeaningDataset`` built (and cached to ``./data/``
        as ``.pt`` files) or loaded from cache, then wrapped in
        ``BatchGroupMeaningDataloader``s.

    Args:
        config: experiment configuration object providing
            ``train_batch_size``, ``val_batch_size``, ``num_proc``,
            ``dataloader_works``, ``model_config.vocab_size`` and a
            ``dataset`` section selecting/parameterizing the dataset.

    Returns:
        Tuple ``(train_dataloader, val_dataloader)``.

    Raises:
        ValueError: if ``config.dataset.name`` is not a known dataset name.
    """
    train_batch_size = config.train_batch_size
    val_batch_size = config.val_batch_size
    num_proc = config.num_proc

    if config.dataset.name == "special":
        raw_dataset = SpecialDataset()
        # 95/5 random train/validation split.
        train_dataset, val_dataset = random_split(raw_dataset, [0.95, 0.05])
        train_dataloader = DataLoader(
            train_dataset,
            batch_size=train_batch_size,
            num_workers=num_proc,
            persistent_workers=True,
            shuffle=True,
        )
        # Validation loader: no shuffling, same worker settings.
        val_dataloader = DataLoader(
            val_dataset,
            batch_size=val_batch_size,
            num_workers=num_proc,
            persistent_workers=True,
        )
        return train_dataloader, val_dataloader

    if config.dataset.name == "meaning":
        conf = config.dataset.meaning
        vocab = config.model_config.vocab_size
        # Meaning ids start above the plain-token vocabulary range.
        # NOTE(review): `start` can be a float when level_ratio is non-integer
        # (unlike `size`, it is not cast to int) and is embedded verbatim in
        # the cache filename — confirm this is intended.
        start = vocab * (conf.level_ratio**conf.level)
        size = vocab * int((conf.level_ratio**conf.dataset_level))

        path = "./data/"
        # NOTE(review): the "v{size}" field looks like it was meant to be
        # "v{vocab}" — kept as-is so existing cache files stay valid.
        trainfile = path + f"MeaningDataset_train_v{size}_s{start}_s{size}_lr{conf.level_ratio}_ms{conf.min_subitem}.pt"
        valfile = path + f"MeaningDataset_val_v{size}_s{start}_s{size}_lr{conf.level_ratio}_ms{conf.min_subitem}.pt"
        # Race-free replacement for the exists()+mkdir() pattern: another
        # process creating the directory between the two calls no longer crashes.
        os.makedirs(path, exist_ok=True)
        if os.path.exists(trainfile) and os.path.exists(valfile):
            print(f"INFO: Load dataset from {trainfile}")
            print(f"INFO: Load dataset from {valfile}")
            # NOTE(review): torch.load unpickles arbitrary objects — only load
            # cache files produced by trusted runs of this code.
            train_dataset = torch.load(trainfile)
            val_dataset = torch.load(valfile)
            print(f"INFO: Load dataset end")
        else:
            # Build the dataset from scratch and cache both splits to disk.
            raw_dataset = MeaningDataset(start, start + size, vocab, None, conf.level_ratio, conf.min_subitem)
            print("INFO: raw_dataset.token_frequency" + raw_dataset.token_frequency())
            raw_dataset.set_mask(conf.mask_level, conf.mask_idx)
            train_dataset, val_dataset = raw_dataset.split(0.9)
            torch.save(train_dataset, trainfile)
            torch.save(val_dataset, valfile)
            print(f"INFO: Build and save dataset end")

        train_dataloader = BatchGroupMeaningDataloader(train_dataset, train_batch_size).dataloader(
            config.dataloader_works
        )
        val_dataloader = BatchGroupMeaningDataloader(val_dataset, val_batch_size).dataloader(config.dataloader_works)
        return train_dataloader, val_dataloader

    # Fail fast instead of implicitly returning None, which callers unpacking
    # a (train, val) tuple would hit as an opaque TypeError.
    raise ValueError(f"Unknown dataset name: {config.dataset.name}")
|