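"""Synthetic "structured language" dataset utilities.

MeaningMap expands integer meanings into token sequences, MeaningDataset samples
those sequences as a torch Dataset, and BatchGroupMeaningDataloader groups
samples of equal length so they can be stacked into rectangular batches.
"""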
import math
import os
import random
from itertools import chain
from typing import Dict, Tuple

import datasets
import numpy as np
import torch
from torch.utils.data import BatchSampler, ConcatDataset, DataLoader, Dataset, random_split


class MeaningMap:
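    """Map from integer "meanings" to token sequences.

    Meanings below vocab_size are terminal tokens.  Every larger meaning expands
    into up to max_subitem smaller meanings, recursively, until only tokens
    remain.  For example (hypothetical values), meaning 5000 might expand to
    [4097, 12, 350] and 4097 in turn to [7, 90], so get_sequence(5000) would
    return [7, 90, 12, 350].  The expansion table and the flattened sequences
    are cached as .npy files under ./data/.
    """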

    def __init__(self, size=1048576, vocab_size=4096, max_subitem=10):
        self.size = size
        self.vocab_size = vocab_size
        self.max_subitem = max_subitem

        path = "./data/"
        file = "structured_language_" + str(size) + "_" + str(vocab_size) + "_" + str(max_subitem)
        file = path + file
        file_map = file + "_map" + ".npy"
        file_start = file + "_start" + ".npy"
        file_len = file + "_len" + ".npy"
        file_data = file + "_data" + ".npy"

        if not os.path.exists(path):
            os.mkdir(path)

        if (
            os.path.exists(file_start)
            and os.path.exists(file_len)
            and os.path.exists(file_data)
            and os.path.exists(file_map)
        ):
            print("Load from disk cache: " + file)
            self.ms_map = np.load(file_map)
            self.ms_data = np.load(file_data)
            self.ms_start = np.load(file_start)
            self.ms_len = np.load(file_len)
            print("Load end")
        else:
            print("Disk cache miss, build new one.")

            index = np.arange(0, size)
            mm = np.random.random((size, max_subitem))

            # Zero out a random number of trailing columns in each row (column 0
            # always survives), so every meaning keeps a random number of sub-items.
            mask_zero = mm.copy()
            mask_zero[:, 0] = 0.0
            mask_zero.sort(axis=1)
            thre = np.random.random((size)).reshape(-1, 1).repeat(max_subitem, axis=1)
            mask_zero = mask_zero > thre

            # Scale each row so the kept values sum to at most the meaning's own id,
            # which keeps every sub-item smaller than its parent meaning.
            item_sum = mm.sum(axis=1)
            scale = (index / item_sum).reshape(-1, 1).repeat(max_subitem, axis=1)
            mm = mm * scale
            mm[mask_zero] = 0

            # Meanings below vocab_size are terminal tokens that map to themselves.
            mm[:vocab_size, 0] = np.arange(0, vocab_size)
            mm[:vocab_size, 1:] = 0
            mm = mm.astype(np.int32)

            ms = []  # meaning sequence
            ms_start = []  # meaning sequence start
            ms_len = []  # meaning sequence length
            index = 0
            for i in range(self.vocab_size):
                ms.append([i])
                ms_start.append(index)
                ms_len.append(1)
                index = index + 1

            # Expand every composite meaning by concatenating the (already built)
            # sequences of its sub-meanings.
            for i in range(self.vocab_size, size):
                m = mm[i]
                m = m[m > 0]
                ma = []
                for newm in m.tolist():
                    ma = ma + ms[newm]
                ms.append(ma)
                ms_start.append(index)
                ms_len.append(len(ma))
                index = index + len(ma)

            ms_data = list(chain(*ms))
            np.save(file_map, np.array(mm).astype(np.int32))
            np.save(file_data, np.array(ms_data).astype(np.int32))
            np.save(file_start, np.array(ms_start).astype(np.int32))
            np.save(file_len, np.array(ms_len).astype(np.int32))

            self.ms_map = mm
            self.ms_data = ms_data
            self.ms_start = ms_start
            self.ms_len = ms_len
            print("Disk cache build end.")

    def get_sequence(self, meaning):
        # Return the flattened token sequence for one meaning.
        start = self.ms_start[meaning]
        length = self.ms_len[meaning]
        return self.ms_data[start : start + length]

    def get_mapping(self, meaning):
        # Return the expansion tree of a meaning as nested dicts; leaves are token ids.
        mapping = {}
        ms = self.ms_map[meaning]
        for m in ms[ms > 0].tolist():
            mapping[m] = self.get_mapping(m) if m >= self.vocab_size else m
        return mapping

    def max_length(self):
        return max(self.ms_len)


class MeaningDataset(Dataset):
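    """Torch Dataset of token sequences sampled from a MeaningMap.

    Each item is a dict with "input_ids", "labels" (a copy of the inputs) and
    "token_type_ids".  Sequences keep their natural, variable length, so batches
    are normally formed with BatchGroupMeaningDataloader, which only groups
    sequences of equal length together.
    """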

    def __init__(
        self,
        start=131072,
        end=1048576,
        size=32768,
        vocab_size=4096,
        max_subitem=10,
        min_seq_len=2,
        seed=42,
        data=None,
        length=None,
        mapping=None,
    ):
        # When data/length/mapping are given directly (see split()), just wrap them.
        if data is not None and length is not None and mapping is not None:
            self.data = data
            self.length = length
            self.mapping = mapping
            return
        np.random.seed(seed)
        mm = MeaningMap(size=end, vocab_size=vocab_size, max_subitem=max_subitem)  # 1048576
        self.mapping = []
        self.data = []
        self.length = []
        meanings = np.random.randint(start, end, size=(size))
        for m in meanings:
            sq = mm.get_sequence(m)
            if len(sq) >= min_seq_len:
                self.mapping.append({m: mm.get_mapping(m)})
                self.data.append(sq)
                self.length.append(len(sq))

        unique, counts = np.unique(self.length, return_counts=True)
        print("----------------------------------------------------------------")
        print("MeaningDataset start:" + str(start) + " end:" + str(end) + " space:" + str(end - start))
        print("MeaningDataset size:" + str(len(self.length)))
        print("MeaningDataset max sequence length:" + str(max(unique)))
        print("MeaningDataset most popular sequence length:" + str(unique[np.argmax(counts)]))
        print("----------------------------------------------------------------")

    def __len__(self):
        return len(self.data)

    def len(self):
        return len(self.data)

    def __getitem__(self, idx):
        output = {}
        data = torch.tensor(self.data[idx]).long()
        output["input_ids"] = data
        output["labels"] = data.clone()
        output["token_type_ids"] = torch.zeros(data.shape)
        return output

    def get_batch(self, index_list):  # all indexed sequences must have the same length
        data = [self.data[i] for i in index_list]
        output = {}
        data = torch.tensor(np.stack(data, axis=0)).long()
        output["input_ids"] = data
        output["labels"] = data.clone()
        output["token_type_ids"] = torch.zeros(data.shape)
        return output

    def get_token_batch(self, index_list):  # all indexed sequences must have the same length
        return [self.data[i] for i in index_list]

    def print_token_batch(self, index_list):  # all indexed sequences must have the same length
        data = [self.data[i] for i in index_list]
        output = {}
        data = torch.tensor(np.stack(data, axis=0)).long()
        output["input_ids"] = data
        output["labels"] = data.clone()
        output["token_type_ids"] = torch.zeros(data.shape)
        return output

    def get_mapping_batch(self, index_list):
        return [self.mapping[i] for i in index_list]

    @staticmethod
    def __get_mapping_str__(mapping, prefix):
        # Render a nested mapping dict as an indented multi-line string.
        if isinstance(mapping, dict):
            base = ""
            for key, value in mapping.items():
                base += prefix + str(key) + "\n"
                base += MeaningDataset.__get_mapping_str__(value, prefix + " ")
            return base
        else:
            return ""

    def print_mapping_batch(self, index_list):
        tokens = self.get_token_batch(index_list)
        maps = self.get_mapping_batch(index_list)
        s = "--------------------------------------------------------\n"
        for i, m in enumerate(maps):
            s += str(tokens[i]) + "\n"
            s += MeaningDataset.__get_mapping_str__(m, "")
            s += "--------------------------------------------------------\n"
        return s

    def split(self, ratio):
        # Split into two datasets that share the underlying items (no reshuffling).
        total = len(self.data)
        middle = int(total * ratio)
        d_shuffle = self.data.copy()
        l_shuffle = self.length.copy()
        m_shuffle = self.mapping.copy()
        md1 = MeaningDataset(data=d_shuffle[:middle], length=l_shuffle[:middle], mapping=m_shuffle[:middle])
        md2 = MeaningDataset(data=d_shuffle[middle:], length=l_shuffle[middle:], mapping=m_shuffle[middle:])
        return md1, md2


class BatchGroupMeaningDataloader(Dataset):
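    """Length-grouped batching over a MeaningDataset.

    Indices are grouped by sequence length and cut into fixed-size batches, so
    every batch can be stacked into one rectangular tensor without padding.
    __getitem__ returns an already collated batch dict, so instances can be
    iterated directly; if wrapped in a torch DataLoader, batch_size=None would
    keep the pre-built batches intact.
    """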

    def __init__(self, dataset: MeaningDataset, batch_size, shuffle=True, drop_last=True):
        self.dataset = dataset
        self.batch_size = batch_size
        self.drop_last = drop_last

        # Group sample indices by sequence length.
        length = dataset.length
        unique, counts = np.unique(length, return_counts=True)
        gl = {}
        for u in unique:
            gl[u] = np.where(length == u)[0]

        lens = list(gl.keys())
        gs = {}
        if shuffle:
            for k in gl.keys():
                sl = gl[k].copy()
                np.random.shuffle(sl)
                gs[k] = sl
        else:
            for k in gl.keys():
                sl = gl[k].copy()
                gs[k] = sl

        # Cut each length group into full batches; indices that do not fill a
        # complete batch are dropped.
        index = np.zeros((0, batch_size), dtype=np.int64)
        for l in lens:
            batch = len(gs[l]) // batch_size
            new = gs[l][0 : batch * batch_size].reshape(batch, batch_size)
            index = np.concatenate((index, new), axis=0)

        if shuffle:
            index_shuffle = np.arange(0, index.shape[0])
            np.random.shuffle(index_shuffle)
            index = index[index_shuffle]
        self.indexBatch = index
        print("Dataloader batch size:" + str(batch_size) + " count:" + str(len(index)))
        print("Dataloader total:" + str(len(length)) + " drop:" + str(len(length) - len(index) * batch_size))

    def __len__(self):
        return len(self.indexBatch)

    def __getitem__(self, idx):
        return self.dataset.get_batch(self.indexBatch[idx])

    def mapping(self, idx):
        return self.dataset.get_mapping_batch(self.indexBatch[idx])

    def print_mapping(self, idx):
        return self.dataset.print_mapping_batch(self.indexBatch[idx])


if __name__ == "__main__":
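    # Smoke test: build a small dataset, split it, and pull a few batches both
    # from the grouped dataloader and from a plain torch DataLoader.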
    md = MeaningDataset(100000, 115200, vocab_size=1024, size=1024)
    train, val = md.split(0.95)

    dl = BatchGroupMeaningDataloader(train, 2)
    length = len(dl)
    it = iter(dl)
    ne1 = next(it)
    ne2 = next(it)
    ne3 = next(it)

    map1 = dl.mapping(0)
    map2 = dl.mapping(1)
    print(dl.print_mapping(0))

    dl = DataLoader(
        train,
        num_workers=1,
        persistent_workers=True,
        shuffle=False,
    )
    it = iter(dl)
    ne1 = next(it)
    ne2 = next(it)
    ne3 = next(it)

    for i in range(10):
        daf = next(it)["input_ids"].numpy().tolist()

    print(daf)