import os
import datasets
import torch
import math
import random
from itertools import chain
from typing import Dict, Tuple
from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split

import numpy as np
from torch.utils.data import BatchSampler


class MeaningMap:
    """Map integer "meanings" to token sequences, cached on disk as .npy files.

    Meanings ``0 .. vocab_size - 1`` are atomic: meaning ``t`` is the
    one-token sequence ``[t]``. Every larger meaning ``i`` is decomposed into
    up to ``max_subitem`` randomly chosen earlier meanings whose sequences are
    concatenated, so sequences are built up recursively.

    The flattened result is held in three arrays:
      * ``ms_data``  - all sequences concatenated back to back,
      * ``ms_start`` - offset of meaning ``i``'s sequence inside ``ms_data``,
      * ``ms_len``   - length of meaning ``i``'s sequence.

    The arrays are cached in the current working directory under names derived
    from ``size``, ``vocab_size`` and ``max_subitem`` only.
    NOTE(review): the cache key does not include the RNG state, so a cached
    map is reused regardless of the seed that was active when it was built.
    """

    # 16777216 1048576 8192
    def __init__(self, size=1048576, vocab_size=4096, max_subitem=10):
        self.size = size
        self.vocab_size = vocab_size
        self.max_subitem = max_subitem

        file = f"structured_language_{size}_{vocab_size}_{max_subitem}"
        file_start = file + "_start" + ".npy"
        file_len = file + "_len" + ".npy"
        file_data = file + "_data" + ".npy"

        if all(os.path.exists(f) for f in (file_start, file_len, file_data)):
            print("Load from disk cache: " + file)
            self.ms_data = np.load(file_data)
            self.ms_start = np.load(file_start)
            self.ms_len = np.load(file_len)
        else:
            print("Disk cache miss, build new one.")
            ms_data, ms_start, ms_len = self._build(size, vocab_size, max_subitem)
            np.save(file_data, ms_data)
            np.save(file_start, ms_start)
            np.save(file_len, ms_len)
            # Store numpy arrays (not Python lists) so the object behaves the
            # same whether or not the cache already existed.
            self.ms_data = ms_data
            self.ms_start = ms_start
            self.ms_len = ms_len
            print("Disk cache build end.")

    @staticmethod
    def _build(size, vocab_size, max_subitem):
        """Randomly generate the meaning decomposition and flatten it.

        Uses the global numpy RNG, so the result depends on its current state.

        Returns:
            ``(ms_data, ms_start, ms_len)`` as numpy arrays
            (int32, int64, int32).

        Raises:
            ValueError: if ``size < vocab_size``.
        """
        if size < vocab_size:
            raise ValueError(f"size ({size}) must be >= vocab_size ({vocab_size})")

        # Random weight per (meaning, sub-item slot).
        mm = np.random.random((size, max_subitem))

        # Per-row mask of dropped slots: after forcing slot 0 to 0 and sorting,
        # slot 0 is never above the threshold, so at least one slot survives;
        # later slots are dropped when their sorted value exceeds a per-row
        # random threshold.
        mask_zero = mm.copy()
        mask_zero[:, 0] = 0.0
        mask_zero.sort(axis=1)
        thre = np.random.random(size).reshape(-1, 1).repeat(max_subitem, axis=1)
        mask_zero = mask_zero > thre

        # Scale row i so its weights (before masking) sum to i; the kept,
        # truncated entries are then indices of earlier meanings (assuming
        # max_subitem >= 2 so no single entry reaches i).
        item_sum = mm.sum(axis=1)
        scale = (np.arange(0, size) / item_sum).reshape(-1, 1).repeat(max_subitem, axis=1)
        mm = mm * scale
        mm[mask_zero] = 0

        # Atomic meanings: meaning t is made of itself only.
        mm[:vocab_size, 0] = np.arange(0, vocab_size)
        mm[:vocab_size, 1:] = 0
        mm = mm.astype(np.int32)

        ms = [[t] for t in range(vocab_size)]  # meaning sequences
        for i in range(vocab_size, size):
            subs = mm[i]
            subs = subs[subs > 0]
            # Concatenate the sub-meaning sequences in slot order.
            ms.append(list(chain.from_iterable(ms[j] for j in subs.tolist())))

        ms_len = np.fromiter((len(s) for s in ms), dtype=np.int32, count=len(ms))
        # int64 offsets: the total token count can exceed the int32 range,
        # which would wrap silently when cast.
        ms_start = np.cumsum(ms_len, dtype=np.int64) - ms_len
        ms_data = np.fromiter(chain.from_iterable(ms), dtype=np.int32)
        return ms_data, ms_start, ms_len

    def GetSequence(self, meaning):
        """Return the token sequence of ``meaning`` (a slice of ``ms_data``)."""
        start = self.ms_start[meaning]
        length = self.ms_len[meaning]
        return self.ms_data[start : start + length]

    def MaxLength(self):
        """Return the length of the longest meaning sequence."""
        return int(np.max(self.ms_len))


class MeaningDataset(Dataset):
    """Dataset of token sequences for meanings drawn from ``[start, end)``.

    ``size`` meanings are drawn uniformly with ``np.random.randint`` after
    seeding the global numpy RNG with ``seed``; sequences shorter than
    ``min_seq_len`` are discarded, so ``len(self)`` can be less than ``size``.

    Passing both ``data`` and ``length`` skips generation and wraps the given
    sequences and their lengths directly (used by :meth:`Split`).
    """

    def __init__(
        self,
        start=131072,
        end=1048576,
        size=32768,
        vocab_size=4096,
        max_subitem=10,
        min_seq_len=2,
        seed=42,
        data=None,
        length=None,
    ):
        if data is not None and length is not None:
            self.data = data
            self.length = length
            return
        np.random.seed(seed)
        mm = MeaningMap(size=end, vocab_size=vocab_size, max_subitem=max_subitem)  # 1048576
        self.data = []
        self.length = []
        meanings = np.random.randint(start, end, size=(size))
        for m in meanings:
            sq = mm.GetSequence(m)
            if len(sq) >= min_seq_len:
                self.data.append(sq)
                self.length.append(len(sq))

    def __len__(self):
        return len(self.data)

    def len(self):
        return len(self.data)

    @staticmethod
    def _make_output(data):
        """Wrap a LongTensor of token ids as input_ids / labels / token_type_ids."""
        return {
            "input_ids": data,
            "labels": data.clone(),
            "token_type_ids": torch.zeros(data.shape),
        }

    def __getitem__(self, idx):
        return self._make_output(torch.tensor(self.data[idx]).long())

    def GetBatch(self, index_list):
        """Return one batch built from the samples at ``index_list``.

        All selected sequences must have the same length (rectangular batch).
        """
        # Stack into one ndarray first: torch.tensor on a list of ndarrays is
        # very slow and emits a warning.
        batch = np.array([self.data[i] for i in index_list])
        return self._make_output(torch.tensor(batch).long())

    def Split(self, ratio):
        """Split into two datasets at ``int(len(self) * ratio)``.

        The original order is kept (no shuffling is performed here).
        """
        middle = int(len(self.data) * ratio)
        first = MeaningDataset(data=self.data[:middle], length=self.length[:middle])
        second = MeaningDataset(data=self.data[middle:], length=self.length[middle:])
        return first, second


class BatchGroupMeaningDataloader(Dataset):
    """Dataset of batches where every batch holds sequences of equal length.

    Sample indices are grouped by sequence length and cut into batches of
    ``batch_size``. With ``drop_last=True`` each group's remainder is
    discarded and ``self.index`` is a ``(num_batches, batch_size)`` array;
    with ``drop_last=False`` the remainders become smaller final batches and
    ``self.index`` is a list of 1-D index arrays. With ``shuffle=True`` the
    indices inside each group and then the batch order are shuffled once,
    at construction time, using the global numpy RNG.
    """

    def __init__(self, dataset: MeaningDataset, batch_size, shuffle=True, drop_last=True):
        self.dataset = dataset
        self.batch_size = batch_size
        self.drop_last = drop_last

        length = np.asarray(dataset.length)

        blocks = [np.zeros((0, batch_size), dtype=np.int64)]
        tails = []
        for u in np.unique(length):
            members = np.where(length == u)[0]  # fresh array, safe to shuffle
            if shuffle:
                np.random.shuffle(members)
            n_full = len(members) // batch_size
            cut = n_full * batch_size
            blocks.append(members[:cut].reshape(n_full, batch_size))
            if not drop_last and cut < len(members):
                tails.append(members[cut:])

        index = np.concatenate(blocks, axis=0)
        if not drop_last:
            index = list(index) + tails

        if shuffle:
            order = np.arange(0, len(index))
            np.random.shuffle(order)
            if isinstance(index, np.ndarray):
                index = index[order]
            else:
                index = [index[i] for i in order]
        self.index = index

    def __len__(self):
        return len(self.index)

    def __getitem__(self, idx):
        # print("get idx" + str(idx))
        return self.dataset.GetBatch(self.index[idx])


if __name__ == "__main__":
    md = MeaningDataset(4096, 8100, size=1024)
    train, val = md.Split(0.95)

    dl = BatchGroupMeaningDataloader(train, 2)
    it = iter(dl)
    ne1 = next(it)
    ne2 = next(it)
    ne3 = next(it)

    dl = DataLoader(
        train,
        num_workers=1,
        persistent_workers=True,
        shuffle=False,
    )
    it = iter(dl)
    ne1 = next(it)
    ne2 = next(it)
    ne3 = next(it)

    for i in range(10):
        daf = next(it)["input_ids"].numpy().tolist()
        print(daf)