# Witllm/wit/meaning_dataset.py
# (source listing metadata: 307 lines, 9.9 KiB, Python)

import os
import datasets
import torch
import math
import random
from itertools import chain
from typing import Dict, Tuple
from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split
import numpy as np
from torch.utils.data import BatchSampler
class MeaningMap:
    """Synthetic hierarchical "meaning" table, cached as .npy files under ./data/.

    Meaning ids ``0 .. vocab_size - 1`` are atomic tokens. Every larger id ``i``
    (up to ``size - 1``) is decomposed into up to ``max_subitem`` smaller meaning
    ids, stored as row ``i`` of ``ms_map`` (0 marks an unused slot). Recursively
    expanding those sub-meanings down to tokens yields the token sequence of ``i``.

    All token sequences are stored back to back in ``ms_data``; the sequence of
    meaning ``i`` is ``ms_data[ms_start[i] : ms_start[i] + ms_len[i]]``. The four
    ``ms_*`` attributes are int32 numpy arrays whether they were loaded from the
    disk cache or freshly built, so callers see the same types on every run.

    The table is drawn from numpy's global RNG, so seed it beforehand for a
    reproducible build. The cache file name encodes only ``size``,
    ``vocab_size`` and ``max_subitem``, not the seed.
    """

    def __init__(self, size=1048576, vocab_size=4096, max_subitem=10):
        """Load the table from the disk cache, or build it and write the cache.

        Args:
            size: number of meaning ids (rows of ``ms_map``).
            vocab_size: number of atomic token ids; must not exceed ``size``.
            max_subitem: maximum number of sub-meanings per meaning.
        """
        self.size = size
        self.vocab_size = vocab_size
        self.max_subitem = max_subitem
        path = "./data/"
        file = path + "structured_language_" + str(size) + "_" + str(vocab_size) + "_" + str(max_subitem)
        file_map = file + "_map" + ".npy"
        file_start = file + "_start" + ".npy"
        file_len = file + "_len" + ".npy"
        file_data = file + "_data" + ".npy"
        # exist_ok tolerates an existing directory without a check-then-create race.
        os.makedirs(path, exist_ok=True)
        if all(os.path.exists(f) for f in (file_start, file_len, file_data, file_map)):
            print("Load from disk cache: " + file)
            self.ms_map = np.load(file_map)
            self.ms_data = np.load(file_data)
            self.ms_start = np.load(file_start)
            self.ms_len = np.load(file_len)
            print("Load end")
            return

        print("Disk cache miss, build new one.")
        mm = self._build_map(size, vocab_size, max_subitem)
        ms_data, ms_start, ms_len = self._expand_sequences(mm, vocab_size)
        np.save(file_map, mm)
        np.save(file_data, ms_data)
        np.save(file_start, ms_start)
        np.save(file_len, ms_len)
        # Keep the same int32 arrays that were just saved, so a freshly built
        # map behaves exactly like one loaded from the cache.
        self.ms_map = mm
        self.ms_data = ms_data
        self.ms_start = ms_start
        self.ms_len = ms_len
        print("Disk cache build end.")

    @staticmethod
    def _build_map(size, vocab_size, max_subitem):
        """Draw the random decomposition table (int32, shape ``(size, max_subitem)``)."""
        weights = np.random.random((size, max_subitem))
        # Each row keeps a random-length prefix of its slots. Sort a copy of the
        # weights with slot 0 forced to 0, so slot 0 is never dropped. Then drop
        # every slot whose sorted value exceeds a per-row random threshold.
        ranked = weights.copy()
        ranked[:, 0] = 0.0
        ranked.sort(axis=1)
        drop = ranked > np.random.random((size))[:, None]
        # Scale row i so its weights (summed before dropping) total i. The kept
        # sub-meanings therefore come out smaller than i.
        mm = weights * (np.arange(0, size) / weights.sum(axis=1))[:, None]
        mm[drop] = 0
        # The first vocab_size rows are the tokens themselves.
        mm[:vocab_size, 0] = np.arange(0, vocab_size)
        mm[:vocab_size, 1:] = 0
        return mm.astype(np.int32)

    @staticmethod
    def _expand_sequences(mm, vocab_size):
        """Expand every row of ``mm`` into its flat token sequence.

        Returns ``(data, start, length)`` as int32 arrays in the concatenated
        layout described on the class.
        """
        # Tokens expand to themselves.
        sequences = [[token] for token in range(vocab_size)]
        # Rows are visited in increasing id order. Each sub-meaning is smaller
        # than its row id, so its sequence already exists in `sequences`.
        for row in mm[vocab_size:]:
            seq = []
            for sub in row[row > 0].tolist():
                seq.extend(sequences[sub])
            sequences.append(seq)
        lengths = np.array([len(s) for s in sequences], dtype=np.int64)
        starts = np.cumsum(lengths) - lengths
        data = np.fromiter(chain.from_iterable(sequences), dtype=np.int32, count=int(lengths.sum()))
        return data, starts.astype(np.int32), lengths.astype(np.int32)

    def get_sequence(self, meaning):
        """Return the flat token sequence of ``meaning`` (a slice of ``ms_data``)."""
        begin = self.ms_start[meaning]
        return self.ms_data[begin : begin + self.ms_len[meaning]]

    def get_mapping(self, meaning):
        """Return the decomposition tree of ``meaning`` as nested dicts.

        Keys are the non-zero sub-meaning ids in row ``meaning`` of ``ms_map``.
        A key below ``vocab_size`` (a token) maps to itself; any other key maps
        to its own sub-tree. Repeated sub-meanings collapse into a single key.
        """
        row = self.ms_map[meaning]
        return {m: (self.get_mapping(m) if m >= self.vocab_size else m) for m in row[row > 0].tolist()}

    def max_length(self):
        """Return the length of the longest token sequence in the table."""
        return np.max(self.ms_len)
class MeaningDataset(Dataset):
    """Token-sequence dataset sampled from a ``MeaningMap``.

    Each sample is the flat token sequence of one meaning id drawn uniformly
    from ``[start, end)``, together with that meaning's decomposition tree.
    A dataset can also be assembled from ready-made ``data``/``length``/``mapping``
    lists, which is how :meth:`split` builds its two halves.

    Attributes:
        data: token sequences, one per sample.
        length: sequence lengths, parallel to ``data``.
        mapping: ``{meaning: MeaningMap.get_mapping(meaning)}`` dicts, parallel to ``data``.
    """

    def __init__(
        self,
        start=131072,
        end=1048576,
        size=32768,
        vocab_size=4096,
        max_subitem=10,
        min_seq_len=2,
        seed=42,
        data=None,
        length=None,
        mapping=None,
    ):
        """Sample a new dataset, or wrap precomputed lists.

        Args:
            start: lowest meaning id to sample (inclusive).
            end: highest meaning id to sample (exclusive); also the ``size`` of
                the MeaningMap that is loaded or built.
            size: number of meaning ids drawn before filtering.
            vocab_size: forwarded to MeaningMap.
            max_subitem: forwarded to MeaningMap.
            min_seq_len: samples whose sequence is shorter than this are dropped.
            seed: seeds numpy's global RNG before sampling (and before any
                MeaningMap build that happens here).
            data, length, mapping: when all three are given, they are used as-is
                and no sampling takes place.
        """
        if data is not None and length is not None and mapping is not None:
            self.data = data
            self.length = length
            self.mapping = mapping
            return
        np.random.seed(seed)
        mm = MeaningMap(size=end, vocab_size=vocab_size, max_subitem=max_subitem)  # 1048576
        self.mapping = []
        self.data = []
        self.length = []
        meanings = np.random.randint(start, end, size=(size))
        for m in meanings:
            sq = mm.get_sequence(m)
            if len(sq) >= min_seq_len:
                self.mapping.append({m: mm.get_mapping(m)})
                self.data.append(sq)
                self.length.append(len(sq))
        unique, counts = np.unique(self.length, return_counts=True)
        print("----------------------------------------------------------------")
        print("MeaningDataset start:" + str(start) + " end:" + str(end) + " space:" + str(end - start))
        print("MeaningDataset size:" + str(len(self.length)))
        # Every sample may have been filtered out; max()/argmax() of nothing would raise.
        if len(unique) > 0:
            print("MeaningDataset max sequence length:" + str(max(unique)))
            print("MeaningDataset most popular sequence length:" + str(unique[np.argmax(counts)]))
        print("----------------------------------------------------------------")

    def __len__(self):
        return len(self.data)

    def len(self):
        """Alias of ``len(dataset)``, kept for existing callers."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return one sample as ``input_ids``/``labels`` (int64) and float ``token_type_ids`` zeros."""
        output = {}
        data = torch.tensor(self.data[idx]).long()
        output["input_ids"] = data
        output["labels"] = data.clone()
        output["token_type_ids"] = torch.zeros(data.shape)
        return output

    def get_batch(self, index_list):
        """Stack the selected samples into one batch dict.

        All selected sequences must have the same length, because they are
        stacked into a dense ``(batch, seq_len)`` tensor.
        """
        data = [self.data[i] for i in index_list]
        output = {}
        data = torch.tensor(np.stack(data, axis=0)).long()
        output["input_ids"] = data
        output["labels"] = data.clone()
        output["token_type_ids"] = torch.zeros(data.shape)
        return output

    def get_token_batch(self, index_list):
        """Return the raw token sequences of the selected samples."""
        return [self.data[i] for i in index_list]

    def print_token_batch(self, index_list):
        """Same as :meth:`get_batch`; kept for existing callers (it does not print)."""
        return self.get_batch(index_list)

    def get_mapping_batch(self, index_list):
        """Return the decomposition trees of the selected samples."""
        return [self.mapping[i] for i in index_list]

    @staticmethod
    def __get_mapping_str__(map, prefix):
        """Render a nested mapping dict as indented lines (one extra space per level).

        Non-dict values (token leaves) contribute nothing beyond their key line.
        """
        if not isinstance(map, dict):
            return ""
        return "".join(
            prefix + str(key) + "\n" + MeaningDataset.__get_mapping_str__(value, prefix + " ")
            for key, value in map.items()
        )

    def print_mapping_batch(self, index_list):
        """Return a printable string showing each sample's tokens and its decomposition tree."""
        tokens = self.get_token_batch(index_list)
        maps = self.get_mapping_batch(index_list)
        parts = ["--------------------------------------------------------\n"]
        for token_seq, tree in zip(tokens, maps):
            parts.append(str(token_seq) + "\n")
            parts.append(MeaningDataset.__get_mapping_str__(tree, ""))
            parts.append("--------------------------------------------------------\n")
        return "".join(parts)

    def split(self, ratio):
        """Split into ``(first, rest)``, where ``first`` holds the first ``int(len * ratio)`` samples.

        The order of samples is preserved; nothing is shuffled here.
        """
        middle = int(len(self.data) * ratio)
        # Slicing a list produces a new list, so both halves are independent of self.
        md1 = MeaningDataset(data=self.data[:middle], length=self.length[:middle], mapping=self.mapping[:middle])
        md2 = MeaningDataset(data=self.data[middle:], length=self.length[middle:], mapping=self.mapping[middle:])
        return md1, md2
class BatchGroupMeaningDataloader(Dataset):
    """Index-addressable batches in which every sequence has the same length.

    Sample indices of ``dataset`` are grouped by sequence length, and each group
    is cut into batches of ``batch_size``. Every batch can therefore be stacked
    into a dense tensor without padding. Item ``idx`` is
    ``dataset.get_batch(indexBatch[idx])``.
    """

    def __init__(self, dataset: "MeaningDataset", batch_size, shuffle=True, drop_last=True):
        """Group, optionally shuffle, and batch the indices of ``dataset``.

        Args:
            dataset: a MeaningDataset. Its ``length`` list drives the grouping;
                its batch accessors serve the items.
            batch_size: samples per batch; must be >= 1.
            shuffle: shuffle samples within each length group and then the
                order of batches, using numpy's global RNG.
            drop_last: if True, indices that do not fill a whole batch within
                their length group are dropped, and ``indexBatch`` is a 2-D int64
                array of shape ``(num_batches, batch_size)``. If False, each
                group's leftovers form one smaller batch, and ``indexBatch`` is
                a list of 1-D index arrays.

        Raises:
            ValueError: if ``batch_size`` < 1.
        """
        if batch_size < 1:
            raise ValueError("batch_size must be >= 1, got " + str(batch_size))
        self.dataset = dataset
        self.batch_size = batch_size
        self.drop_last = drop_last
        # An ndarray makes `lengths == seq_len` an explicit elementwise comparison.
        lengths = np.asarray(dataset.length)

        full_parts = [np.zeros((0, batch_size), dtype=np.int64)]  # used when drop_last
        batches = []  # used when not drop_last
        for seq_len in np.unique(lengths):
            members = np.where(lengths == seq_len)[0]  # np.where returns a fresh array
            if shuffle:
                np.random.shuffle(members)
            n_full = len(members) // batch_size
            full = members[0 : n_full * batch_size].reshape(n_full, batch_size)
            if drop_last:
                full_parts.append(full)
            else:
                batches.extend(full)
                if len(members) > n_full * batch_size:
                    batches.append(members[n_full * batch_size :])

        index = np.concatenate(full_parts, axis=0) if drop_last else batches
        if shuffle:
            order = np.arange(0, len(index))
            np.random.shuffle(order)
            index = index[order] if drop_last else [index[i] for i in order]
        self.indexBatch = index
        kept = sum(len(b) for b in index)
        print("Dataloader batch size:" + str(batch_size) + " count:" + str(len(index)))
        print("Dataloader total:" + str(len(lengths)) + " drop:" + str(len(lengths) - kept))

    def __len__(self):
        return len(self.indexBatch)

    def __getitem__(self, idx):
        """Return batch ``idx`` as built by ``dataset.get_batch``."""
        return self.dataset.get_batch(self.indexBatch[idx])

    def mapping(self, idx):
        """Return the decomposition trees of the samples in batch ``idx``."""
        return self.dataset.get_mapping_batch(self.indexBatch[idx])

    def print_mapping(self, idx):
        """Return the printable token/mapping dump of batch ``idx``."""
        return self.dataset.print_mapping_batch(self.indexBatch[idx])
if __name__ == "__main__":
    # Demo: sample a small dataset and keep 95% of it for training.
    md = MeaningDataset(100000, 115200, vocab_size=1024, size=1024)
    train, val = md.split(0.95)

    # Length-grouped batches: pull three, look at two mappings, dump one.
    dl = BatchGroupMeaningDataloader(train, 2)
    length = len(dl)
    it = iter(dl)
    ne1, ne2, ne3 = next(it), next(it), next(it)
    map1, map2 = dl.mapping(0), dl.mapping(1)
    print(dl.print_mapping(0))

    # The same split through a plain torch DataLoader (default batch size).
    dl = DataLoader(train, num_workers=1, persistent_workers=True, shuffle=False)
    it = iter(dl)
    ne1, ne2, ne3 = next(it), next(it), next(it)
    for _ in range(10):
        row = next(it)["input_ids"].numpy().tolist()
        print(row)