from dataset.meaning_dataset import MeaningDataset, BatchGroupMeaningDataloader
from dataset.special_dataset import SpecialDataset
from torch.utils.data import random_split, DataLoader
import torch
import os


def InitDataset(config):
    """Build train/val dataloaders for the dataset selected by config.dataset.name."""
    train_batch_size = config.train_batch_size
    val_batch_size = config.val_batch_size
    num_proc = config.num_proc

    if config.dataset.name == "special":
        # Plain PyTorch pipeline: 95/5 random split, standard DataLoaders.
        raw_dataset = SpecialDataset()
        train_dataset, val_dataset = random_split(raw_dataset, [0.95, 0.05])
        train_dataloader = DataLoader(
            train_dataset,
            batch_size=train_batch_size,
            num_workers=num_proc,
            persistent_workers=True,
            shuffle=True,
        )
        val_dataloader = DataLoader(
            val_dataset,
            batch_size=val_batch_size,
            num_workers=num_proc,
            persistent_workers=True,
        )
        return train_dataloader, val_dataloader

    if config.dataset.name == "meaning":
        c = config.dataset.meaning
        vocab = config.model_config.vocab_size
        # Id range of the dataset, derived from the vocab size and level settings:
        # it covers [start, start + size).
        start = vocab * (c.level_ratio**c.level)
        size = vocab * int(c.level_ratio**c.dataset_level)

        # Cache the built datasets on disk, keyed by the build parameters.
        path = "./data/"
        trainfile = path + f"MeaningDataset_train_v{size}_s{start}_s{size}_lr{c.level_ratio}_ms{c.min_subitem}.pt"
        valfile = path + f"MeaningDataset_val_v{size}_s{start}_s{size}_lr{c.level_ratio}_ms{c.min_subitem}.pt"
        if not os.path.exists(path):
            os.mkdir(path)

        if os.path.exists(trainfile) and os.path.exists(valfile):
            # Cache hit: load both splits and re-apply the mask settings.
            print(f"INFO: Load dataset from {trainfile}")
            train_dataset = torch.load(trainfile, weights_only=False)
            train_dataset.set_mask(c.mask_level, c.mask_idx)
            print(f"INFO: Load dataset from {valfile}")
            val_dataset = torch.load(valfile, weights_only=False)
            val_dataset.set_mask(c.mask_level, c.mask_idx)
            print("INFO: Load dataset end")
        else:
            # Cache miss: build the dataset, split 90/10, and save both parts.
            raw_dataset = MeaningDataset(start, start + size, vocab, None, c.level_ratio, c.min_subitem)
            raw_dataset.set_mask(c.mask_level, c.mask_idx)
            train_dataset, val_dataset = raw_dataset.split(0.9)
            torch.save(train_dataset, trainfile)
            torch.save(val_dataset, valfile)
            print("INFO: Build and save dataset end")

        # Project-specific dataloader wrapper; .dataloader() takes the worker count.
        train_dataloader = BatchGroupMeaningDataloader(train_dataset, train_batch_size).dataloader(
            config.dataloader_works
        )
        val_dataloader = BatchGroupMeaningDataloader(val_dataset, val_batch_size).dataloader(config.dataloader_works)
        return train_dataloader, val_dataloader
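

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original module). The
# config layout mirrors the attributes InitDataset reads above; every concrete
# value (vocab_size, level_ratio, batch sizes, ...) is an assumed placeholder,
# not a value taken from the original project.
if __name__ == "__main__":
    from types import SimpleNamespace

    meaning_cfg = SimpleNamespace(
        level_ratio=5,  # assumed ratio between meaning levels
        level=2,  # assumed base level used to compute the id range start
        dataset_level=3,  # assumed level exponent that sets the dataset size
        min_subitem=2,  # assumed minimum number of sub-items per meaning
        mask_level=None,  # no masking in this sketch
        mask_idx=None,
    )
    config = SimpleNamespace(
        train_batch_size=16,
        val_batch_size=16,
        num_proc=4,
        dataloader_works=4,  # worker count forwarded to .dataloader()
        model_config=SimpleNamespace(vocab_size=4096),
        dataset=SimpleNamespace(name="meaning", meaning=meaning_cfg),
    )

    train_dataloader, val_dataloader = InitDataset(config)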