# Witllm/wit/meaning_dataset.py
# (Repository-viewer header: 236 lines, 7.2 KiB, Python.)

import os
import datasets
import torch
import math
import random
from itertools import chain
from typing import Dict, Tuple
from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split
import numpy as np
from torch.utils.data import BatchSampler
class MeaningMap:  # 16777216 1048576 8192
    """Map from integer "meanings" to token sequences, cached on disk as .npy files.

    Meanings ``0 .. vocab_size-1`` are atomic: each maps to the one-token
    sequence ``[meaning]``.  Every meaning ``i >= vocab_size`` is built from up
    to ``max_subitem`` randomly weighted parts whose ids are smaller than ``i``
    (their ids sum to at most ``i``); its token sequence is the concatenation of
    the parts' sequences.  Parts whose id truncates to 0 are dropped, so a
    composite meaning may end up with fewer parts or even an empty sequence.

    The map is stored as three flat numpy arrays:
    ``ms_data`` (all tokens concatenated), ``ms_start`` (offset of each meaning
    in ``ms_data``) and ``ms_len`` (length of each meaning's sequence).

    Generation draws from the global ``np.random`` state, so callers should seed
    it first for reproducibility.  NOTE(review): the cache file name does not
    include the seed, so a cache built under one seed is reused under any other.
    """

    def __init__(self, size=1048576, vocab_size=4096, max_subitem=10, path="./data/"):
        """Load the map from the disk cache under ``path``, or build and cache it.

        Args:
            size: number of meanings (ids ``0 .. size-1``).
            vocab_size: number of atomic meanings (token ids).
            max_subitem: maximum number of parts a composite meaning has.
            path: cache directory; created (with parents) if missing.
        """
        self.size = size
        self.vocab_size = vocab_size
        self.max_subitem = max_subitem
        file = os.path.join(
            path,
            "structured_language_" + str(size) + "_" + str(vocab_size) + "_" + str(max_subitem),
        )
        file_start = file + "_start" + ".npy"
        file_len = file + "_len" + ".npy"
        file_data = file + "_data" + ".npy"
        os.makedirs(path, exist_ok=True)
        if os.path.exists(file_start) and os.path.exists(file_len) and os.path.exists(file_data):
            print("Load from disk cache: " + file)
            self.ms_data = np.load(file_data)
            self.ms_start = np.load(file_start)
            self.ms_len = np.load(file_len)
            print("Load end")
        else:
            print("Disk cache miss, build new one.")
            # Keep the freshly built map as numpy arrays too, so attribute types
            # do not depend on whether the cache existed.
            self.ms_data, self.ms_start, self.ms_len = self._build(size, vocab_size, max_subitem)
            np.save(file_data, self.ms_data)
            np.save(file_start, self.ms_start)
            np.save(file_len, self.ms_len)
            print("Disk cache build end.")

    @staticmethod
    def _build(size, vocab_size, max_subitem):
        """Generate the flat ``(ms_data, ms_start, ms_len)`` arrays for a new map."""
        index = np.arange(0, size)
        # One row of random part weights per meaning.
        mm = np.random.random((size, max_subitem))
        # Decide how many parts each row keeps: sort a copy of the row with
        # column 0 forced to 0 (so the first part always survives). The columns
        # whose sorted weight exceeds a per-row threshold form a suffix, and
        # those trailing columns are dropped.
        mask_zero = mm.copy()
        mask_zero[:, 0] = 0.0
        mask_zero.sort(axis=1)
        thre = np.random.random(size)
        mask_zero = mask_zero > thre[:, None]
        # Scale each row so its weights sum to the row's own id, then zero the
        # dropped parts. After int truncation every kept part id is < the row id,
        # so parts always refer to already-built meanings.
        item_sum = mm.sum(axis=1)
        mm = mm * (index / item_sum)[:, None]
        mm[mask_zero] = 0
        mm = mm.astype(np.int32)

        ms = [[i] for i in range(vocab_size)]  # atomic meanings map to themselves
        for i in range(vocab_size, size):
            row = mm[i]
            seq = []
            # Zero entries are dropped/truncated parts.
            # NOTE(review): this also means meaning 0 is never used as a part.
            for part in row[row > 0].tolist():
                seq.extend(ms[part])  # in-place extend avoids O(len^2) concatenation
            ms.append(seq)

        ms_len = np.fromiter((len(s) for s in ms), dtype=np.int32, count=len(ms))
        # int64 offsets: the concatenated data can outgrow int32 for large maps.
        ms_start = np.zeros(len(ms), dtype=np.int64)
        ms_start[1:] = np.cumsum(ms_len[:-1], dtype=np.int64)
        total = int(ms_len.sum(dtype=np.int64))
        ms_data = np.fromiter(chain.from_iterable(ms), dtype=np.int32, count=total)
        return ms_data, ms_start, ms_len

    def GetSequence(self, meaning):
        """Return the token sequence of ``meaning`` as a view into ``ms_data``."""
        start = int(self.ms_start[meaning])
        seq_len = int(self.ms_len[meaning])
        return self.ms_data[start : start + seq_len]

    def MaxLength(self):
        """Return the length of the longest sequence in the map."""
        return int(np.max(self.ms_len))
class MeaningDataset(Dataset):
    """Dataset of token sequences sampled from a ``MeaningMap``.

    By default it draws ``size`` random meanings from ``[start, end)`` and keeps
    those whose token sequence has at least ``min_seq_len`` tokens.  When both
    ``data`` and ``length`` are given, it instead wraps those pre-computed
    sequences and lengths as-is (this is how ``Split`` builds its halves).
    """

    def __init__(
        self,
        start=131072,
        end=1048576,
        size=32768,
        vocab_size=4096,
        max_subitem=10,
        min_seq_len=2,
        seed=42,
        data=None,
        length=None,
    ):
        # `is not None` (not `!= None`) so numpy-array inputs do not trigger an
        # elementwise comparison with an ambiguous truth value.
        if data is not None and length is not None:
            self.data = data
            self.length = length
            return
        # Seeds the global numpy RNG: this fixes both the MeaningMap build (on a
        # cache miss) and the meanings drawn below.
        np.random.seed(seed)
        mm = MeaningMap(size=end, vocab_size=vocab_size, max_subitem=max_subitem)  # 1048576
        self.data = []
        self.length = []
        meanings = np.random.randint(start, end, size=(size))
        for m in meanings:
            sq = mm.GetSequence(m)
            if len(sq) >= min_seq_len:
                self.data.append(sq)
                self.length.append(len(sq))
        print("MeaningDataset size: " + str(len(self.length)))
        if not self.length:
            # Nothing survived the min_seq_len filter; length stats are undefined.
            return
        unique, counts = np.unique(self.length, return_counts=True)
        print("MeaningDataset max sequence length: " + str(max(unique)))
        print("MeaningDataset most popular sequence length: " + str(unique[np.argmax(counts)]))

    def __len__(self):
        return len(self.data)

    def len(self):
        """Alias of ``len(self)``."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return one sample as ``input_ids`` / ``labels`` / ``token_type_ids`` tensors.

        ``input_ids`` is an int64 tensor and ``labels`` is a copy of it.
        ``token_type_ids`` is a float32 zero tensor of the same shape.
        NOTE(review): an embedding lookup would need an integer dtype here;
        confirm what the consumer expects.
        """
        output = {}
        data = torch.tensor(self.data[idx]).long()
        output["input_ids"] = data
        output["labels"] = data.clone()
        output["token_type_ids"] = torch.zeros(data.shape)
        return output

    def GetBatch(self, index_list):  # must equal sequence length
        """Stack the samples at ``index_list`` into ``(batch, seq)`` tensors.

        All selected sequences must have the same length (``np.stack`` requires it).
        """
        data = [self.data[i] for i in index_list]
        output = {}
        data = torch.tensor(np.stack(data, axis=0)).long()
        output["input_ids"] = data
        output["labels"] = data.clone()
        output["token_type_ids"] = torch.zeros(data.shape)
        return output

    def Split(self, ratio):
        """Split into two datasets at index ``int(len(self) * ratio)``, keeping order.

        No shuffling is done here. When the dataset was built by sampling, its
        order is already the random draw order.
        """
        middle = int(len(self.data) * ratio)
        first = MeaningDataset(data=self.data[:middle], length=self.length[:middle])
        second = MeaningDataset(data=self.data[middle:], length=self.length[middle:])
        return first, second
class BatchGroupMeaningDataloader(Dataset):
    """Dataset whose items are whole batches of equal-length sequences.

    Sample indices of ``dataset`` are grouped by sequence length (ascending) and
    cut into batches of ``batch_size``.  Item ``idx`` is
    ``dataset.GetBatch(indexBatch[idx])``.

    With ``shuffle``, the indices inside each length group and then the order of
    the batches are shuffled using the global ``np.random`` state.  With
    ``drop_last`` (the default), each group's incomplete tail is discarded and
    ``indexBatch`` is an int64 array of shape ``(num_batches, batch_size)``.
    Otherwise each tail becomes a smaller batch and ``indexBatch`` is a list of
    1-D index arrays.
    """

    def __init__(self, dataset: "MeaningDataset", batch_size, shuffle=True, drop_last=True):
        if batch_size <= 0:
            raise ValueError("batch_size must be positive, got " + str(batch_size))
        self.dataset = dataset
        self.batch_size = batch_size
        self.drop_last = drop_last
        # asarray so `length == u` is a real elementwise comparison even when
        # dataset.length is a plain Python list.
        length = np.asarray(dataset.length)
        batches = []
        for u in np.unique(length):  # sorted: groups are emitted by ascending length
            members = np.where(length == u)[0]  # fresh array, safe to shuffle in place
            if shuffle:
                np.random.shuffle(members)
            full = len(members) // batch_size
            batches.extend(members[0 : full * batch_size].reshape(full, batch_size))
            if not drop_last and len(members) % batch_size:
                batches.append(members[full * batch_size :])
        order = np.arange(0, len(batches))
        if shuffle:
            np.random.shuffle(order)
        if drop_last:
            index = np.array(batches, dtype=np.int64).reshape(-1, batch_size)
            self.indexBatch = index[order]
        else:
            self.indexBatch = [batches[i] for i in order]

    def __len__(self):
        return len(self.indexBatch)

    def __getitem__(self, idx):
        # print("get idx" + str(idx))
        return self.dataset.GetBatch(self.indexBatch[idx])
if __name__ == "__main__":
    # Manual smoke test: sample 1024 meanings from [4096, 8100), split 95/5,
    # then pull samples through the length-grouped batcher and through a plain
    # single-worker torch DataLoader, printing ten input_ids sequences.
    dataset = MeaningDataset(4096, 8100, size=1024)
    train, val = dataset.Split(0.95)

    grouped = BatchGroupMeaningDataloader(train, 32)
    num_batches = len(grouped)
    grouped_iter = iter(grouped)
    first_batches = [next(grouped_iter) for _ in range(3)]

    loader = DataLoader(train, num_workers=1, persistent_workers=True, shuffle=False)
    loader_iter = iter(loader)
    for _ in range(3):
        next(loader_iter)  # skip the first three samples
    for _ in range(10):
        sample = next(loader_iter)
        print(sample["input_ids"].numpy().tolist())