from meaning_dataset import MeaningDataset, BatchGroupMeaningDataloader
from special_dataset import SpecialDataset
from torch.utils.data import random_split, DataLoader


def InitDataset(config):
    train_batch_size = config.train_batch_size
    val_batch_size = config.val_batch_size
    num_proc = config.num_proc

    if config.dataset.name == "special":
        # Toy dataset: 95/5 random train/val split served by standard PyTorch DataLoaders.
        raw_dataset = SpecialDataset()
        train_dataset, val_dataset = random_split(raw_dataset, [0.95, 0.05])
        train_dataloader = DataLoader(
            train_dataset,
            batch_size=train_batch_size,
            num_workers=num_proc,
            persistent_workers=True,
            shuffle=True,
        )
        val_dataloader = DataLoader(
            val_dataset,
            batch_size=val_batch_size,
            num_workers=num_proc,
            persistent_workers=True,
        )
        return train_dataloader, val_dataloader

    if config.dataset.name == "meaning":
        conf = config.dataset.meaning
        vocab = config.model_config.vocab_size
        # Meaning id range: [start, start + size), derived from the vocab size and level ratios.
        start = vocab * (conf.level_ratio**conf.level)
        size = vocab * int(conf.level_ratio**conf.dataset_level)
        raw_dataset = MeaningDataset(start, start + size, vocab, None, conf.level_ratio, conf.min_subitem)
        # print(raw_dataset.token_frequency())
        raw_dataset.set_mask(conf.mask_level, conf.mask_idx)
        # 90/10 train/val split; batches are grouped by the meaning-aware dataloader.
        train_dataset, val_dataset = raw_dataset.split(0.9)
        train_dataloader = BatchGroupMeaningDataloader(train_dataset, train_batch_size).dataloader(
            config.dataloader_works
        )
        val_dataloader = BatchGroupMeaningDataloader(val_dataset, val_batch_size).dataloader(config.dataloader_works)
        return train_dataloader, val_dataloader
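

# Minimal usage sketch (an assumption, not part of the original module): InitDataset
# reads a nested attribute-style config (e.g. an OmegaConf / SimpleNamespace object).
# The "special" branch is exercised here because it needs the fewest fields; the
# concrete values below are illustrative only.
if __name__ == "__main__":
    from types import SimpleNamespace

    demo_config = SimpleNamespace(
        train_batch_size=8,
        val_batch_size=8,
        num_proc=2,  # persistent_workers=True requires num_workers > 0
        dataset=SimpleNamespace(name="special"),
    )
    train_loader, val_loader = InitDataset(demo_config)
    print(f"train batches: {len(train_loader)}, val batches: {len(val_loader)}")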