Witllm/wit/dataset/dataset.py
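"""Dataset construction for wit.

InitDataset builds (train, val) dataloaders for the "special" and "meaning"
datasets. Generated MeaningDataset splits are cached under ./data/ and
reloaded with torch.load on subsequent runs.
"""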

import os

import torch
from torch.utils.data import DataLoader, random_split

from dataset.meaning_dataset import MeaningDataset, BatchGroupMeaningDataloader
from dataset.special_dataset import SpecialDataset

def InitDataset(config):
    """Build and return (train_dataloader, val_dataloader) for the dataset named in config."""
    train_batch_size = config.train_batch_size
    val_batch_size = config.val_batch_size
    num_proc = config.num_proc

    if config.dataset.name == "special":
        # Synthetic dataset: 95/5 random split served by standard DataLoaders.
        raw_dataset = SpecialDataset()
        train_dataset, val_dataset = random_split(raw_dataset, [0.95, 0.05])
        train_dataloader = DataLoader(
            train_dataset,
            batch_size=train_batch_size,
            num_workers=num_proc,
            persistent_workers=True,
            shuffle=True,
        )
        val_dataloader = DataLoader(
            val_dataset,
            batch_size=val_batch_size,
            num_workers=num_proc,
            persistent_workers=True,
        )
        return train_dataloader, val_dataloader
    if config.dataset.name == "meaning":
        conf = config.dataset.meaning
        vocab = config.model_config.vocab_size
        # Meaning ids are addressed relative to the vocab size: this dataset
        # covers the id range [start, start + size).
        start = vocab * (conf.level_ratio**conf.level)
        size = vocab * int(conf.level_ratio**conf.dataset_level)
        path = "./data/"
        trainfile = path + f"MeaningDataset_train_v{size}_s{start}_s{size}_lr{conf.level_ratio}_ms{conf.min_subitem}.pt"
        valfile = path + f"MeaningDataset_val_v{size}_s{start}_s{size}_lr{conf.level_ratio}_ms{conf.min_subitem}.pt"
        os.makedirs(path, exist_ok=True)
        if os.path.exists(trainfile) and os.path.exists(valfile):
            # Cached splits exist: reload them and apply the mask from the
            # current config.
            print(f"INFO: Load dataset from {trainfile}")
            train_dataset = torch.load(trainfile, weights_only=False)
            train_dataset.set_mask(conf.mask_level, conf.mask_idx)
            print(f"INFO: Load dataset from {valfile}")
            val_dataset = torch.load(valfile, weights_only=False)
            val_dataset.set_mask(conf.mask_level, conf.mask_idx)
            print("INFO: Load dataset end")
        else:
            # No cache: generate the dataset, split 90/10, and save both parts.
            raw_dataset = MeaningDataset(start, start + size, vocab, None, conf.level_ratio, conf.min_subitem)
            raw_dataset.set_mask(conf.mask_level, conf.mask_idx)
            train_dataset, val_dataset = raw_dataset.split(0.9)
            torch.save(train_dataset, trainfile)
            torch.save(val_dataset, valfile)
            print("INFO: Build and save dataset end")
        # Wrap the splits in the project's batch-grouping dataloader.
        train_dataloader = BatchGroupMeaningDataloader(train_dataset, train_batch_size).dataloader(
            config.dataloader_works
        )
        val_dataloader = BatchGroupMeaningDataloader(val_dataset, val_batch_size).dataloader(config.dataloader_works)
        return train_dataloader, val_dataloader

    raise ValueError(f"Unknown dataset name: {config.dataset.name}")
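
# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). The config below is
# a hypothetical stand-in built from SimpleNamespace with just the fields the
# "special" branch reads; the real project presumably passes a richer config
# object. Assumes the module is run from the project root so the `dataset`
# package imports resolve.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    demo_config = SimpleNamespace(
        train_batch_size=32,
        val_batch_size=32,
        num_proc=2,
        dataset=SimpleNamespace(name="special"),
    )
    train_loader, val_loader = InitDataset(demo_config)
    print(f"train batches: {len(train_loader)}, val batches: {len(val_loader)}")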