Witllm/wit/meaning_dataset.py

229 lines
6.8 KiB
Python
Raw Normal View History

2024-03-13 19:41:02 +08:00
import os
import datasets
import torch
import math
import random
from itertools import chain
from typing import Dict, Tuple
from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split
import numpy as np
2024-03-18 11:43:41 +08:00
from torch.utils.data import BatchSampler
2024-03-13 19:41:02 +08:00
class MeaningMap:  # 16777216 1048576 8192
    """Table that expands every "meaning" id into a flat sequence of tokens.

    Ids below ``vocab_size`` are primitive tokens: the sequence of id ``i`` is
    just ``[i]``.  Every other id is given up to ``max_subitem`` randomly
    drawn sub-ids, and its sequence is the concatenation of the sequences of
    those sub-ids (in column order).

    The table is stored as three int32 arrays:

    * ``ms_data``  - all sequences concatenated back to back,
    * ``ms_start`` - offset of each id's sequence inside ``ms_data``,
    * ``ms_len``   - length of each id's sequence.

    The arrays are cached under ``./data/`` in ``.npy`` files whose names
    encode ``size``, ``vocab_size`` and ``max_subitem``; when all three files
    exist they are loaded instead of being rebuilt.  Building draws from
    NumPy's global random generator, so the result depends on its seed.

    Args:
        size: Total number of meaning ids in the table.
        vocab_size: Number of primitive ids (``0 .. vocab_size - 1``).
        max_subitem: Maximum number of sub-ids a composite id is built from.
    """

    def __init__(self, size=1048576, vocab_size=4096, max_subitem=10):
        self.size = size
        self.vocab_size = vocab_size
        self.max_subitem = max_subitem

        path = "./data/"
        file = path + "structured_language_" + str(size) + "_" + str(vocab_size) + "_" + str(max_subitem)
        file_start = file + "_start" + ".npy"
        file_len = file + "_len" + ".npy"
        file_data = file + "_data" + ".npy"

        # exist_ok avoids the check-then-create race of exists() + mkdir().
        os.makedirs(path, exist_ok=True)

        if all(os.path.exists(f) for f in (file_start, file_len, file_data)):
            print("Load from disk cache: " + file)
            self.ms_data = np.load(file_data)
            self.ms_start = np.load(file_start)
            self.ms_len = np.load(file_len)
            return

        print("Disk cache miss, build new one.")
        ms_data, ms_start, ms_len = self._build(size, vocab_size, max_subitem)
        np.save(file_data, ms_data)
        np.save(file_start, ms_start)
        np.save(file_len, ms_len)
        # Keep exactly what was written to disk, so a freshly built map and a
        # map loaded from the cache expose the same array types.
        self.ms_data = ms_data
        self.ms_start = ms_start
        self.ms_len = ms_len
        print("Disk cache build end.")

    @staticmethod
    def _build(size, vocab_size, max_subitem):
        """Generate the table and return ``(ms_data, ms_start, ms_len)`` as int32 arrays."""
        index = np.arange(0, size)
        mm = np.random.random((size, max_subitem))

        # Decide which columns of each row get blanked: sort a copy of the
        # row (with column 0 forced to 0.0, so it is never selected) and
        # compare it against one random threshold per row.  Because the copy
        # is sorted, the selected columns are always the trailing ones.
        mask_zero = mm.copy()
        mask_zero[:, 0] = 0.0
        mask_zero.sort(axis=1)
        thre = np.random.random((size)).reshape(-1, 1)
        mask_zero = mask_zero > thre

        # Scale row i so that its entries sum to i; each entry is therefore
        # expected to end up below i, i.e. to reference an earlier id.
        scale = (index / mm.sum(axis=1)).reshape(-1, 1)
        mm = mm * scale
        mm[mask_zero] = 0

        # Primitive ids reference only themselves.
        mm[:vocab_size, 0] = np.arange(0, vocab_size)
        mm[:vocab_size, 1:] = 0
        mm = mm.astype(np.int32)

        ms = [[i] for i in range(vocab_size)]  # meaning sequences
        for i in range(vocab_size, size):
            row = mm[i]
            sequence = []
            # Entries equal to 0 (blanked or truncated to zero) are skipped.
            for sub in row[row > 0].tolist():
                sequence.extend(ms[sub])
            ms.append(sequence)

        lengths = np.array([len(seq) for seq in ms], dtype=np.int64)
        starts = np.cumsum(lengths) - lengths  # offset of each sequence
        ms_data = np.array(list(chain.from_iterable(ms))).astype(np.int32)
        return ms_data, starts.astype(np.int32), lengths.astype(np.int32)

    def GetSequence(self, meaning):
        """Return the token sequence stored for id ``meaning``."""
        start = self.ms_start[meaning]
        length = self.ms_len[meaning]
        return self.ms_data[start : start + length]

    def MaxLength(self):
        """Return the length of the longest sequence in the table."""
        return int(np.max(self.ms_len))
2024-03-13 19:41:02 +08:00
class MeaningDataset(Dataset):
    """Dataset of token sequences sampled from a ``MeaningMap``.

    Two construction modes are supported:

    * ``data`` and ``length`` both given: they are adopted as-is and every
      other argument is ignored (this is what ``Split`` uses).
    * otherwise: NumPy's global random generator is seeded with ``seed``, a
      ``MeaningMap`` of ``end`` ids is created, ``size`` ids are drawn
      uniformly from ``[start, end)`` and the sequences that are at least
      ``min_seq_len`` tokens long are kept.

    Args:
        start: Lowest meaning id that may be sampled (inclusive).
        end: Upper bound of sampled ids (exclusive); also the map size.
        size: Number of ids drawn; the dataset may be smaller after filtering.
        vocab_size: Forwarded to ``MeaningMap``.
        max_subitem: Forwarded to ``MeaningMap``.
        min_seq_len: Sequences shorter than this are discarded.
        seed: Seed passed to ``np.random.seed``.
        data: Optional pre-built collection of sequences.
        length: Optional lengths matching ``data`` element by element.
    """

    def __init__(
        self,
        start=131072,
        end=1048576,
        size=32768,
        vocab_size=4096,
        max_subitem=10,
        min_seq_len=2,
        seed=42,
        data=None,
        length=None,
    ):
        # Identity comparison: `!= None` is elementwise on NumPy arrays and
        # makes the `and` raise for arrays with more than one element.
        if data is not None and length is not None:
            self.data = data
            self.length = length
            return

        np.random.seed(seed)
        mm = MeaningMap(size=end, vocab_size=vocab_size, max_subitem=max_subitem)  # 1048576

        self.data = []
        self.length = []
        meanings = np.random.randint(start, end, size=(size))
        for m in meanings:
            sq = mm.GetSequence(m)
            if len(sq) >= min_seq_len:
                self.data.append(sq)
                self.length.append(len(sq))

    def __len__(self):
        return len(self.data)

    def len(self):
        """Return the number of sequences (same value as ``len(self)``)."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return one sample as a dict of tensors."""
        data = torch.tensor(self.data[idx]).long()
        return {
            "input_ids": data,
            "labels": data.clone(),
            # NOTE(review): torch.zeros defaults to a floating dtype here;
            # confirm that consumers do not expect integer token_type_ids.
            "token_type_ids": torch.zeros(data.shape),
        }

    def GetBatch(self, index_list):  # must equal sequence length
        """Stack the samples at ``index_list`` into one batch.

        All selected sequences must have the same length, otherwise
        ``np.stack`` raises.
        """
        rows = [self.data[i] for i in index_list]
        data = torch.tensor(np.stack(rows, axis=0)).long()
        return {
            "input_ids": data,
            "labels": data.clone(),
            "token_type_ids": torch.zeros(data.shape),
        }

    def Split(self, ratio):
        """Split into two datasets at ``int(len(self) * ratio)``.

        The order is preserved (nothing is shuffled): the first dataset gets
        the leading part, the second one the rest.
        """
        middle = int(len(self.data) * ratio)
        data_copy = self.data.copy()
        length_copy = self.length.copy()
        md1 = MeaningDataset(data=data_copy[:middle], length=length_copy[:middle])
        md2 = MeaningDataset(data=data_copy[middle:], length=length_copy[middle:])
        return md1, md2
class BatchGroupMeaningDataloader(Dataset):
    """Pre-computed batches in which all sequences have the same length.

    Sample indices of ``dataset`` are grouped by ``dataset.length``; every
    group is cut into chunks of ``batch_size`` indices.  Indexing this object
    with ``i`` returns ``dataset.GetBatch(<indices of batch i>)``.

    Args:
        dataset: Object exposing ``length`` (one entry per sample) and
            ``GetBatch(index_list)``.
        batch_size: Number of samples per batch.
        shuffle: If true, shuffle the samples inside each length group and
            then the order of the batches, using NumPy's global generator.
        drop_last: If true (default), the incomplete remainder of each length
            group is discarded and ``indexBatch`` is a 2-D int64 array of
            shape ``(num_batches, batch_size)``.  If false, the remainder is
            kept as one shorter batch and ``indexBatch`` is a list of 1-D
            index arrays.
    """

    def __init__(self, dataset: "MeaningDataset", batch_size, shuffle=True, drop_last=True):
        self.dataset = dataset
        self.batch_size = batch_size
        self.drop_last = drop_last

        length = np.asarray(dataset.length)
        # One index array per distinct length, in ascending length order.
        groups = [np.where(length == u)[0] for u in np.unique(length)]
        if shuffle:
            for group in groups:
                np.random.shuffle(group)

        batches = []
        for group in groups:
            full = len(group) // batch_size
            cut = full * batch_size
            batches.extend(group[:cut].reshape(full, batch_size))
            if not drop_last and cut < len(group):
                batches.append(group[cut:])

        if shuffle:
            order = np.arange(0, len(batches))
            np.random.shuffle(order)
            batches = [batches[i] for i in order]

        if drop_last:
            # All batches are full, so they fit a rectangular array.
            self.indexBatch = np.array(batches, dtype=np.int64).reshape(len(batches), batch_size)
        else:
            self.indexBatch = batches

    def __len__(self):
        return len(self.indexBatch)

    def __getitem__(self, idx):
        return self.dataset.GetBatch(self.indexBatch[idx])
2024-03-18 11:43:41 +08:00
2024-03-13 19:41:02 +08:00
if __name__ == "__main__":
    # Manual smoke test: build a small dataset, split it, and pull a few
    # batches through both loading paths.
    md = MeaningDataset(4096, 8100, size=1024)
    train, val = md.Split(0.95)

    # Length-grouped batches of two sequences each.
    dl = BatchGroupMeaningDataloader(train, 2)
    it = iter(dl)
    ne1, ne2, ne3 = next(it), next(it), next(it)

    # Plain torch DataLoader over single samples, served by one worker.
    dl = DataLoader(train, num_workers=1, persistent_workers=True, shuffle=False)
    it = iter(dl)
    ne1, ne2, ne3 = next(it), next(it), next(it)

    # Print the token ids of the next ten samples.
    for _ in range(10):
        daf = next(it)["input_ids"].numpy().tolist()
        print(daf)