import math
import os
import random
from itertools import chain
from typing import Dict, Tuple

import numpy as np
import torch
from torch.utils.data import BatchSampler, ConcatDataset, DataLoader, Dataset, random_split

try:
    # Not referenced in this file; guarded so the module still imports when
    # the Hugging Face `datasets` package is not installed.
    import datasets
except ImportError:
    datasets = None


class MeaningMap:
    """Synthetic hierarchical "language" of meanings and their token sequences.

    Meanings ``0 .. vocab_size - 1`` are atomic: each is the single token with
    the same id. Every meaning ``i >= vocab_size`` is composed of the non-zero
    entries of ``ms_map[i]`` (at most ``max_subitem`` sub-meanings), and its
    token sequence is the concatenation of those sub-meanings' sequences.

    The tables are cached as ``.npy`` files under ``./data/``, keyed by
    ``size``, ``vocab_size`` and ``max_subitem``. On a cache miss they are
    generated with the *global* NumPy RNG and written to disk.

    Attributes:
        ms_map:   int32 array of shape (size, max_subitem); composition table.
        ms_data:  flat int32 array holding every meaning's token sequence.
        ms_start: offset of meaning ``i``'s sequence inside ``ms_data``.
        ms_len:   length of meaning ``i``'s sequence.

    NOTE(review): sub-meanings are taken from entries ``> 0``, so token 0 can
    never appear inside a composed sequence -- confirm this is intended.
    """

    def __init__(self, size=1048576, vocab_size=4096, max_subitem=10):
        self.size = size
        self.vocab_size = vocab_size
        self.max_subitem = max_subitem

        path = "./data/"
        file = path + "structured_language_" + str(size) + "_" + str(vocab_size) + "_" + str(max_subitem)
        file_map = file + "_map" + ".npy"
        file_start = file + "_start" + ".npy"
        file_len = file + "_len" + ".npy"
        file_data = file + "_data" + ".npy"

        # exist_ok avoids failing if the directory appears concurrently.
        os.makedirs(path, exist_ok=True)

        if all(os.path.exists(f) for f in (file_start, file_len, file_data, file_map)):
            print("Load from disk cache: " + file)
            self.ms_map = np.load(file_map)
            self.ms_data = np.load(file_data)
            self.ms_start = np.load(file_start)
            self.ms_len = np.load(file_len)
            print("Load end")
        else:
            print("Disk cache miss, build new one.")
            # Keep NumPy arrays (exactly what the cache-load branch yields) so
            # the object behaves the same whether or not the cache existed.
            self.ms_map, self.ms_data, self.ms_start, self.ms_len = self._build(size, vocab_size, max_subitem)
            np.save(file_map, self.ms_map)
            np.save(file_data, self.ms_data)
            np.save(file_start, self.ms_start)
            np.save(file_len, self.ms_len)
            print("Disk cache build end.")

    @staticmethod
    def _build(size, vocab_size, max_subitem):
        """Randomly generate the composition table and the flattened sequences.

        Returns:
            ``(ms_map, ms_data, ms_start, ms_len)``, all int32 NumPy arrays.
        """
        # Random positive weight for each slot of each meaning.
        mm = np.random.random((size, max_subitem))

        # Slot-dropping mask: after zeroing column 0 and sorting, the mask is
        # True on a trailing run of slots, so each row keeps a random-length
        # prefix of its slots (slot 0 is always kept since 0 > thre is False).
        mask_zero = mm.copy()
        mask_zero[:, 0] = 0.0
        mask_zero.sort(axis=1)
        thre = np.random.random((size)).reshape(-1, 1).repeat(max_subitem, axis=1)
        mask_zero = mask_zero > thre

        # Rescale row i so its weights sum to i; with max_subitem > 1 every
        # component is then below i, i.e. refers to an already-built meaning.
        index = np.arange(0, size)
        item_sum = mm.sum(axis=1)
        scale = (index / item_sum).reshape(-1, 1).repeat(max_subitem, axis=1)
        mm = mm * scale
        mm[mask_zero] = 0

        # The first vocab_size meanings are atomic: row i is just [i, 0, ...].
        mm[:vocab_size, 0] = np.arange(0, vocab_size)
        mm[:vocab_size, 1:] = 0
        mm = mm.astype(np.int32)

        # Expand every meaning into its token sequence, bottom-up.
        sequences = [[i] for i in range(vocab_size)]
        for i in range(vocab_size, size):
            row = mm[i]
            seq = []
            for sub in row[row > 0].tolist():
                seq += sequences[sub]
            sequences.append(seq)

        ms_len = np.array([len(s) for s in sequences], dtype=np.int64)
        ms_start = np.cumsum(ms_len) - ms_len
        ms_data = np.fromiter(chain.from_iterable(sequences), dtype=np.int32, count=int(ms_len.sum()))
        # NOTE(review): int32 offsets match the existing cache format but would
        # wrap if the total token count ever reaches 2**31.
        return mm, ms_data, ms_start.astype(np.int32), ms_len.astype(np.int32)

    def get_sequence(self, meaning):
        """Return the token sequence of ``meaning`` as a slice of ``ms_data``."""
        start = self.ms_start[meaning]
        return self.ms_data[start : start + self.ms_len[meaning]]

    def get_mapping(self, meaning):
        """Return the composition tree of ``meaning`` as nested dicts.

        Each non-zero sub-meaning ``m`` maps to its own subtree when it is
        composite (``m >= vocab_size``) or to itself when it is atomic.
        Repeated sub-meanings collapse into a single dict key.
        """
        mapping = {}
        row = self.ms_map[meaning]
        for m in row[row > 0].tolist():
            mapping[m] = self.get_mapping(m) if m >= self.vocab_size else m
        return mapping

    def max_length(self):
        """Return the length of the longest sequence in the map."""
        return int(np.max(self.ms_len))


class MeaningDataset(Dataset):
    """Dataset of token sequences sampled from a :class:`MeaningMap`.

    Either samples ``size`` meanings uniformly from ``[start, end)`` (keeping
    those whose sequence has at least ``min_seq_len`` tokens), or -- when
    ``data``, ``length`` and ``mapping`` are all given -- wraps those lists
    directly (used by :meth:`split`).
    """

    def __init__(
        self,
        start=131072,
        end=1048576,
        size=32768,
        vocab_size=4096,
        max_subitem=10,
        min_seq_len=2,
        seed=42,
        data=None,
        length=None,
        mapping=None,
    ):
        # `is not None` (not `!= None`) so array-valued inputs don't raise.
        if data is not None and length is not None and mapping is not None:
            self.data = data
            self.length = length
            self.mapping = mapping
            return

        # Seeds the global NumPy RNG (side effect on other global-RNG users).
        np.random.seed(seed)
        mm = MeaningMap(size=end, vocab_size=vocab_size, max_subitem=max_subitem)  # 1048576
        # Building the map on a cache miss consumes the global RNG; reseed so
        # the sampled meanings are the same whether or not the cache existed.
        np.random.seed(seed)

        self.mapping = []
        self.data = []
        self.length = []
        for m in np.random.randint(start, end, size=(size)):
            sq = mm.get_sequence(m)
            if len(sq) >= min_seq_len:
                self.mapping.append({m: mm.get_mapping(m)})
                self.data.append(sq)
                self.length.append(len(sq))

        print("----------------------------------------------------------------")
        print("MeaningDataset start:" + str(start) + " end:" + str(end) + " space:" + str(end - start))
        print("MeaningDataset size:" + str(len(self.length)))
        if self.length:  # max/argmax below would raise on an empty dataset
            unique, counts = np.unique(self.length, return_counts=True)
            print("MeaningDataset max sequence length:" + str(unique.max()))
            print("MeaningDataset most popular sequence length:" + str(unique[np.argmax(counts)]))
        print("----------------------------------------------------------------")

    def __len__(self):
        return len(self.data)

    def len(self):
        """Alias of ``len(dataset)`` kept for existing callers."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return one sample as ``input_ids`` / ``labels`` / ``token_type_ids`` tensors."""
        data = torch.tensor(self.data[idx]).long()
        output = {}
        output["input_ids"] = data
        output["labels"] = data.clone()
        # NOTE(review): torch.zeros defaults to float32; token type ids are
        # usually integer -- confirm what the consumer expects.
        output["token_type_ids"] = torch.zeros(data.shape)
        return output

    def get_batch(self, index_list):
        """Stack the samples at ``index_list`` into one batch.

        All selected sequences must have the same length (``np.stack``).
        """
        data = torch.tensor(np.stack([self.data[i] for i in index_list], axis=0)).long()
        output = {}
        output["input_ids"] = data
        output["labels"] = data.clone()
        output["token_type_ids"] = torch.zeros(data.shape)
        return output

    def get_token_batch(self, index_list):
        """Return the raw token sequences at ``index_list``."""
        return [self.data[i] for i in index_list]

    def print_token_batch(self, index_list):
        """Same as :meth:`get_batch` (kept for existing callers; prints nothing)."""
        return self.get_batch(index_list)

    def get_mapping_batch(self, index_list):
        """Return the composition trees of the samples at ``index_list``."""
        return [self.mapping[i] for i in index_list]

    @staticmethod
    def __get_mapping_str__(map, prefix):
        """Render a nested mapping dict as indented lines, one key per line.

        ``map`` keeps its original name (it shadows the builtin) so keyword
        callers keep working. Non-dict leaves render as the empty string.
        """
        if not isinstance(map, dict):
            return ""
        return "".join(
            prefix + str(key) + "\n" + MeaningDataset.__get_mapping_str__(value, prefix + " ")
            for key, value in map.items()
        )

    def print_mapping_batch(self, index_list):
        """Return (not print) a text dump of tokens and mapping trees for ``index_list``."""
        separator = "--------------------------------------------------------\n"
        tokens = self.get_token_batch(index_list)
        mappings = self.get_mapping_batch(index_list)
        parts = [separator]
        for tok, mapping in zip(tokens, mappings):
            parts.append(str(tok) + "\n")
            parts.append(MeaningDataset.__get_mapping_str__(mapping, ""))
            parts.append(separator)
        return "".join(parts)

    def split(self, ratio):
        """Split into two datasets at ``int(len * ratio)``, preserving order.

        NOTE(review): no shuffling is done here; the samples are already in
        random order when generated from random meanings.
        """
        middle = int(len(self.data) * ratio)
        md1 = MeaningDataset(data=self.data[:middle], length=self.length[:middle], mapping=self.mapping[:middle])
        md2 = MeaningDataset(data=self.data[middle:], length=self.length[middle:], mapping=self.mapping[middle:])
        return md1, md2


class BatchGroupMeaningDataloader(Dataset):
    """Map-style dataset whose items are whole batches of equal-length samples.

    Sample indices are grouped by sequence length and cut into batches of
    ``batch_size``. With ``drop_last=True`` (default) each group's remainder is
    dropped and ``indexBatch`` is an ``(n_batches, batch_size)`` int array;
    with ``drop_last=False`` each remainder becomes a smaller final batch and
    ``indexBatch`` is a list of 1-D index arrays. ``shuffle`` uses the global
    NumPy RNG to shuffle within each length group and then the batch order.
    """

    def __init__(self, dataset: MeaningDataset, batch_size, shuffle=True, drop_last=True):
        if batch_size <= 0:
            raise ValueError("batch_size must be positive, got " + str(batch_size))
        self.dataset = dataset
        self.batch_size = batch_size
        self.drop_last = drop_last

        length = np.asarray(dataset.length)
        full_blocks = [np.zeros((0, batch_size), dtype=np.int64)]
        ragged = []  # every batch as a 1-D array; only used when drop_last is False
        for seq_len in np.unique(length):
            members = np.where(length == seq_len)[0]
            if shuffle:
                np.random.shuffle(members)
            n_full = len(members) // batch_size
            block = members[0 : n_full * batch_size].reshape(n_full, batch_size)
            full_blocks.append(block)
            if not drop_last:
                ragged.extend(block)
                rest = members[n_full * batch_size :]
                if len(rest):
                    ragged.append(rest)

        index = np.concatenate(full_blocks, axis=0) if drop_last else ragged
        if shuffle:
            order = np.arange(0, len(index))
            np.random.shuffle(order)
            index = index[order] if drop_last else [index[i] for i in order]
        self.indexBatch = index

        used = sum(len(b) for b in index)
        print("Dataloader batch size:" + str(batch_size) + " count:" + str(len(index)))
        print("Dataloader total:" + str(len(length)) + " drop:" + str(len(length) - used))

    def __len__(self):
        return len(self.indexBatch)

    def __getitem__(self, idx):
        return self.dataset.get_batch(self.indexBatch[idx])

    def mapping(self, idx):
        """Return the mapping trees of batch ``idx``."""
        return self.dataset.get_mapping_batch(self.indexBatch[idx])

    def print_mapping(self, idx):
        """Return the text dump of batch ``idx`` (see ``MeaningDataset.print_mapping_batch``)."""
        return self.dataset.print_mapping_batch(self.indexBatch[idx])


if __name__ == "__main__":
    md = MeaningDataset(100000, 115200, vocab_size=1024, size=1024)
    train, val = md.split(0.95)
    dl = BatchGroupMeaningDataloader(train, 2)
    length = len(dl)
    it = iter(dl)
    ne1 = next(it)
    ne2 = next(it)
    ne3 = next(it)
    map1 = dl.mapping(0)
    map2 = dl.mapping(1)
    print(dl.print_mapping(0))

    dl = DataLoader(
        train,
        num_workers=1,
        persistent_workers=True,
        shuffle=False,
    )
    it = iter(dl)
    ne1 = next(it)
    ne2 = next(it)
    ne3 = next(it)
    for i in range(10):
        daf = next(it)["input_ids"].numpy().tolist()
        print(daf)