import os import datasets import torch import math import random from itertools import chain from typing import Dict, Tuple from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split import numpy as np class MeaningMap: # 16777216 1048576 8192 def __init__(self, size=1048576, vocab_size=4096, max_subitem=10): self.size = size self.vocab_size = vocab_size self.max_subitem = max_subitem file = "structured_language_" + str(size) + "_" + str(vocab_size) + "_" + str(max_subitem) file_start = file + "_start" + ".npy" file_len = file + "_len" + ".npy" file_data = file + "_data" + ".npy" if os.path.exists(file_start) and os.path.exists(file_len) and os.path.exists(file_data): print("Load from disk cache: " + file) self.ms_data = np.load(file_data) self.ms_start = np.load(file_start) self.ms_len = np.load(file_len) return None print("Disk cache miss, build new one.") mm = np.empty((size, max_subitem), dtype=np.int32) # total_level = int(math.log(size / vocab_size, max_subitem)) # start = [0] # end = [vocab_size] # shift = vocab_size # for i in range(total_level): # shift = end[-1] # start.append(end[-1]) # end.append(shift * self.max_subitem) # start.append(end[-1]) # end.append(size) index = np.arange(0, size) mm = np.random.random((size, max_subitem)) mask_zero = mm.copy() mask_zero[:, 0] = 0.0 mask_zero.sort(axis=1) thre = np.random.random((size)).reshape(-1, 1).repeat(max_subitem, axis=1) mask_zero = mask_zero > thre item_sum = mm.sum(axis=1) scale = (index / item_sum).reshape(-1, 1).repeat(max_subitem, axis=1) mm = mm * scale mm[mask_zero] = 0 mm[:vocab_size, 0] = np.arange(0, vocab_size) mm[:vocab_size, 1:] = 0 mm = mm.astype(np.int32) ms = [] # meaning sequence ms_start = [] # meaning sequence start ms_len = [] # meaning sequence length index = 0 for i in range(self.vocab_size): ms.append([i]) ms_start.append(index) ms_len.append(1) index = index + 1 for i in range(self.vocab_size, size): m = mm[i] m = m[m > 0] ma = [] for newm in m.tolist(): ma = ma + ms[newm] ms.append(ma) ms_start.append(index) ms_len.append(len(ma)) index = index + len(ma) ms_data = list(chain(*ms)) np.save(file_data, np.array(ms_data).astype(np.int32)) np.save(file_start, np.array(ms_start).astype(np.int32)) np.save(file_len, np.array(ms_len).astype(np.int32)) self.ms_data = ms_data self.ms_start = ms_start self.ms_len = ms_len print("Disk cache build end.") def GetSequence(self, meaning): start = self.ms_start[meaning] len = self.ms_len[meaning] return self.ms_data[start : start + len] class MeaningDataset(Dataset): def __init__(self, start=131072, end=1048576, size=32768, vocab_size=4096, max_subitem=10, seed=42): self.seed = seed np.random.seed(seed) self.size = size self.mm = MeaningMap(size=end, vocab_size=vocab_size, max_subitem=max_subitem) # 1048576 self.data = [] meanings = np.random.randint(start, end, size=(size)) for m in meanings: sq = self.mm.GetSequence(m) if len(sq) > 1: self.data.append(sq) left = size - len(self.data) while True: if left <= 0: break index = np.random.randint(start, end) sq = self.mm.GetSequence(index) if len(sq) > 1: self.data.append(sq) left = left - 1 def __len__(self): return self.size def __getitem__(self, idx): output = {} data = torch.tensor(self.data[idx]).long() output["input_ids"] = data output["labels"] = data.clone() output["token_type_ids"] = torch.zeros(data.shape) return output if __name__ == "__main__": md = MeaningDataset(4096, 4100, size=32768) it = iter(md) for i in range(10): daf = next(it)["input_ids"].numpy().tolist() print(daf)