Witllm/wit/meaning_dataset.py

144 lines
4.4 KiB
Python

import os
import datasets
import torch
import math
import random
from itertools import chain
from typing import Dict, Tuple
from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split
import numpy as np
class MeaningMap:  # 16777216 1048576 8192
    """Random hierarchical mapping from integer "meanings" to token sequences.

    Meanings ``0 .. vocab_size - 1`` are leaf tokens whose sequence is the
    single token itself.  Every larger meaning ``i`` is assigned up to
    ``max_subitem`` smaller meanings, drawn from the global NumPy RNG, and
    its sequence is the concatenation of their sequences, so every sequence
    ultimately consists of leaf token ids.

    The flattened result is cached in the current working directory as three
    ``.npy`` files named after ``size``, ``vocab_size`` and ``max_subitem``.
    The cache name does not include the RNG state, so an existing cache is
    reused regardless of the seed in effect.

    Attributes:
        ms_data: 1-D NumPy array with all sequences concatenated.
        ms_start: NumPy array; offset of each meaning's sequence in ``ms_data``.
        ms_len: NumPy array; length of each meaning's sequence.
    """

    def __init__(self, size=1048576, vocab_size=4096, max_subitem=10):
        self.size = size
        self.vocab_size = vocab_size
        self.max_subitem = max_subitem
        file = f"structured_language_{size}_{vocab_size}_{max_subitem}"
        file_start = file + "_start" + ".npy"
        file_len = file + "_len" + ".npy"
        file_data = file + "_data" + ".npy"
        if os.path.exists(file_start) and os.path.exists(file_len) and os.path.exists(file_data):
            print("Load from disk cache: " + file)
            self.ms_data = np.load(file_data)
            self.ms_start = np.load(file_start)
            self.ms_len = np.load(file_len)
            return
        print("Disk cache miss, build new one.")
        sub_items = self._draw_sub_items()
        data, starts, lengths = self._flatten(sub_items)
        # Use NumPy arrays here too, so GetSequence returns the same type
        # whether or not the cache existed.  Offsets are int64 because the total
        # length of all sequences may exceed the int32 range for large maps.
        self.ms_data = np.array(data, dtype=np.int32)
        self.ms_start = np.array(starts, dtype=np.int64)
        self.ms_len = np.array(lengths, dtype=np.int32)
        np.save(file_data, self.ms_data)
        np.save(file_start, self.ms_start)
        np.save(file_len, self.ms_len)
        print("Disk cache build end.")

    def _draw_sub_items(self):
        """Draw the (size, max_subitem) int32 table of sub-meaning ids per meaning.

        Entries equal to 0 mark unused slots.  Rows below ``vocab_size``
        are left unset because ``_flatten`` never reads them.
        """
        size, max_subitem = self.size, self.max_subitem
        index = np.arange(0, size)
        weights = np.random.random((size, max_subitem))
        # Per row, mask a random-length run of trailing slots.  Force column 0
        # of a copy to 0.0, sort the copy, and compare it to a per-row threshold.
        # Column 0 is never masked because 0.0 > thre is always False.
        mask_zero = weights.copy()
        mask_zero[:, 0] = 0.0
        mask_zero.sort(axis=1)
        thre = np.random.random(size).reshape(-1, 1)
        mask_zero = mask_zero > thre
        # Scale each row so its weights sum to the row's own index.  After
        # truncation, every sub-meaning id in row i is therefore below i.
        scale = (index / weights.sum(axis=1)).reshape(-1, 1)
        table = weights * scale
        table[mask_zero] = 0
        return table.astype(np.int32)

    def _flatten(self, sub_items):
        """Expand every meaning into leaf tokens.

        Returns ``(data, starts, lengths)`` lists, where meaning ``i`` is
        ``data[starts[i] : starts[i] + lengths[i]]``.  Each sequence is copied
        out of the flat ``data`` list itself, so no per-meaning lists are
        kept alive.
        """
        vocab_size = self.vocab_size
        data = list(range(vocab_size))
        starts = list(range(vocab_size))
        lengths = [1] * vocab_size
        for row in sub_items[vocab_size:]:
            begin = len(data)
            # Zero entries (masked slots, or ids truncated to 0) are skipped.
            for sub in row[row > 0].tolist():
                offset = starts[sub]
                data.extend(data[offset : offset + lengths[sub]])
            starts.append(begin)
            lengths.append(len(data) - begin)
        return data, starts, lengths

    def GetSequence(self, meaning):
        """Return the token-id sequence of ``meaning`` as a slice of ``ms_data``."""
        start = self.ms_start[meaning]
        length = self.ms_len[meaning]
        return self.ms_data[start : start + length]
class MeaningDataset(Dataset):
    """Fixed-size dataset of token sequences sampled from a MeaningMap.

    ``size`` meanings are drawn uniformly from ``[start, end)`` using the
    global NumPy RNG seeded with ``seed``.  Meanings whose sequence has at most
    one token are replaced by fresh draws until exactly ``size`` sequences have
    been collected.  Each item is a dict with ``input_ids``, ``labels`` (a copy
    of ``input_ids``) and all-zero ``token_type_ids``.

    Raises:
        ValueError: if ``size > 0`` and no meaning in ``[start, end)`` has a
            sequence longer than one token (the resampling would never end).
    """

    def __init__(self, start=131072, end=1048576, size=32768, vocab_size=4096, max_subitem=10, seed=42):
        self.seed = seed
        np.random.seed(seed)
        self.size = size
        self.mm = MeaningMap(size=end, vocab_size=vocab_size, max_subitem=max_subitem)  # 1048576
        # Building the map consumes the global RNG only on a disk-cache miss.
        # Re-seed so the sampled meanings do not depend on whether the cache existed.
        np.random.seed(seed)
        if size > 0 and not np.any(np.asarray(self.mm.ms_len[start:end]) > 1):
            raise ValueError(f"no meaning in [{start}, {end}) has a sequence longer than one token")
        self.data = []
        for m in np.random.randint(start, end, size=(size)):
            sq = self.mm.GetSequence(m)
            if len(sq) > 1:
                self.data.append(sq)
        # Replace rejected (too short) draws one at a time.
        left = size - len(self.data)
        while left > 0:
            sq = self.mm.GetSequence(np.random.randint(start, end))
            if len(sq) > 1:
                self.data.append(sq)
                left -= 1

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        """Return input ids, labels (a copy) and zero token type ids for item ``idx``."""
        data = torch.tensor(self.data[idx]).long()
        return {
            "input_ids": data,
            "labels": data.clone(),
            "token_type_ids": torch.zeros(data.shape),
        }
if __name__ == "__main__":
    # Demo: build a dataset over a tiny meaning range and print the token ids
    # of its first ten items.
    dataset = MeaningDataset(4096, 4100, size=32768)
    for position in range(10):
        sample = dataset[position]
        print(sample["input_ids"].numpy().tolist())