# Witllm/wit/meaning_dataset.py

import os
import datasets
import torch
import math
import random
from itertools import chain
from typing import Dict, Tuple
from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split
import numpy as np
from torch.utils.data import BatchSampler
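

# MeaningMap builds a synthetic "structured language". Meaning ids below vocab_size
# are leaf tokens; every larger id expands into a random list of strictly smaller
# meaning ids, so each meaning unfolds into a tree whose leaves form its token
# sequence. The flattened sequences and per-token level/path metadata are cached on
# disk as .npy files under ./data/.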
class MeaningMap:
    def __init__(self, size=1048576, vocab_size=4096, max_subitem=10, use_cache=True):
        self.size = size
        self.vocab_size = vocab_size
        self.max_subitem = max_subitem
        path = "./data/"
        file = "structured_language_" + str(size) + "_" + str(vocab_size) + "_" + str(max_subitem)
        file = path + file
        file_slhwm = file + "_slhwm" + ".npy"
        file_dli = file + "_dli" + ".npy"

        if not os.path.exists(path):
            os.mkdir(path)

        if os.path.exists(file_slhwm) and os.path.exists(file_dli) and use_cache:
            print("Load from disk cache: " + file)
            slhwm = np.load(file_slhwm)
            dli = np.load(file_dli)
            self.ms_map = slhwm[:, 4:]
            self.ms_data = dli[:, 0]
            self.ms_start = slhwm[:, 0]
            self.ms_len = slhwm[:, 1]
            self.ms_level = dli[:, 1]
            self.ms_idx = dli[:, 2].astype(np.uint32)
            self.ms_height = slhwm[:, 2]
            self.ms_weight = slhwm[:, 3]
            print("Load end")
        else:
            print("Disk cache miss, build new one.")
            map = np.empty((size, max_subitem), dtype=np.int32)
            index = np.arange(0, size)
            map = np.random.random((size, max_subitem))

            mask_zero = map.copy()
            mask_zero[:, 0] = 0.0
            mask_zero.sort(axis=1)
            thre = np.random.random((size)).reshape(-1, 1).repeat(max_subitem, axis=1)
            mask_zero = mask_zero > thre

            item_sum = map.sum(axis=1)
            scale = (index / item_sum).reshape(-1, 1).repeat(max_subitem, axis=1)
            map = (map * scale).astype(np.int32)
            map[mask_zero] = -1

            map[:vocab_size, 0] = np.arange(0, vocab_size)
            map[:vocab_size, 1:] = -1
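
            # For every meaning, accumulate its flattened leaf-token sequence together
            # with per-token depth (ms_level) and packed child-position path (ms_idx),
            # plus per-meaning start/len/height/weight bookkeeping.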
            ms_data = []  # meaning sequence
            ms_level = []  # meaning level, vocab's level is 0
            ms_idx = []  # meaning index of lowest level
            ms_start = []  # meaning sequence start
            ms_len = []  # meaning sequence length
            ms_height = []  # meaning tree height
            ms_weight = []  # meaning tree weight
            index = 0
            for i in range(self.vocab_size):
                ms_data.append(np.array([i]))
                ms_level.append(np.array([0]))
                ms_idx.append(np.array([0]))
                ms_start.append(index)
                ms_len.append(1)
                ms_height.append(0)
                ms_weight.append(1)
                index = index + 1
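
            # Expand non-leaf meanings bottom-up: every child id is smaller than its
            # parent, so the children's sequences already exist and can simply be
            # concatenated. ms_idx appends the child position as one more base-16
            # digit; leaf children start from the 0xFFFFFFF0 sentinel.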
            for i in range(self.vocab_size, size):
                m = map[i]
                m = m[m >= 0]
                m_list = m.tolist()
                assert m_list, "map list can not be empty list"
                ma = np.concatenate([ms_data[newm] for newm in m_list])
                ml = np.concatenate([ms_level[newm] + 1 for newm in m_list])
                mi = np.concatenate(
                    [
                        ([0xFFFFFFF0 + i] if newm < self.vocab_size else ms_idx[newm] * 16 + i)
                        for i, newm in enumerate(m_list)
                    ]
                )
                ml = ml[ma > 0]
                mi = mi[ma > 0]
                ma = ma[ma > 0]
                ms_data.append(ma)
                ms_level.append(ml)
                ms_idx.append(mi)
                ms_start.append(index)
                ms_len.append(len(ma))
                ms_height.append(max([ms_height[sub_m] for sub_m in m_list]) + 1)
                ms_weight.append(sum(ms_weight[sub_m] for sub_m in m_list))
                index = index + len(ma)

            # offsets = [0, 0, 4, 8, 12, 16, 20, 24, 28]
            # for idxmi, mi in enumerate(ms_idx):
            #     level = ms_level[idxmi]
            #     for idxnum, num in enumerate(mi):
            #         l = level[idxnum]
            #         elements = [(num >> offset) & 0xF for offset in offsets[l:0:-1]]
            #         num = (num >> (l * 4)) << (l * 4)
            #         num += sum(elem << (i * 4) for i, elem in enumerate(elements))
            #         mi[idxnum] = num

            ms_data = np.array(list(chain(*ms_data))).astype(np.int32)
            ms_level = np.array(list(chain(*ms_level))).astype(np.int32)
            ms_idx = np.array(list(chain(*ms_idx))).astype(np.uint32)
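
            # Re-pack ms_idx so that, per token, the path of child positions runs from
            # the root down to the token, one 4-bit nibble per level (deepest level in
            # the lowest nibble), with all unused high nibbles set to 0xF. The nibble
            # swap below reverses the 8 nibbles of each 32-bit value before shifting.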
            d = np.ones(ms_idx.shape, dtype=np.uint32)
            d = ((d * 0xFFFFFFFF) << (ms_level * 4)).astype(np.uint32)
            ms_idx = (
                ((ms_idx & 0xF) << 28)
                + ((ms_idx & 0xF0) << 20)
                + ((ms_idx & 0xF00) << 12)
                + ((ms_idx & 0xF000) << 4)
                + ((ms_idx & 0xF0000) >> 4)
                + ((ms_idx & 0xF00000) >> 12)
                + ((ms_idx & 0xF000000) >> 20)
                + ((ms_idx & 0xF0000000) >> 28)
            )
            ms_idx = ((ms_idx >> ((8 - ms_level) * 4)) + d).astype(np.uint32)

            ms_start = np.array(ms_start).astype(np.int32)
            ms_height = np.array(ms_height).astype(np.int32)
            ms_weight = np.array(ms_weight).astype(np.int32)
            ms_len = np.array(ms_len).astype(np.int32)
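
            # Persist two cache arrays: slhwm packs per-meaning columns
            # [start, len, height, weight, map...]; dli packs per-token columns
            # [data, level, idx]. The load branch above unpacks the same layout.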
            slhwm = np.concatenate(
                (
                    ms_start.reshape((-1, 1)),
                    ms_len.reshape((-1, 1)),
                    ms_height.reshape((-1, 1)),
                    ms_weight.reshape((-1, 1)),
                    map,
                ),
                axis=1,
            )
            dli = np.stack((ms_data, ms_level, ms_idx.astype(np.int32)), axis=1)
            np.save(file_slhwm, slhwm)
            np.save(file_dli, dli)

            self.ms_map = map  # ms_map[i] = [sub(i), sub(i), sub(i), ... sub(i)]
            self.ms_data = ms_data  # map[i] = ms_data[ms_start[i] : ms_start[i] + ms_len[i]]
            self.ms_start = ms_start
            self.ms_len = ms_len
            self.ms_level = ms_level
            self.ms_idx = ms_idx
            self.ms_height = ms_height
            self.ms_weight = ms_weight
            print("Disk cache build end.")

    def get_sequence(self, meaning):  # return sequence[meaning]
        start = self.ms_start[meaning]
        length = self.ms_len[meaning]
        return (
            self.ms_data[start : start + length],
            self.ms_level[start : start + length],
            self.ms_idx[start : start + length],
        )

    def get_tree(self, meaning):  # return all sub items of meaning
        tree = {}
        ms = self.ms_map[meaning]
        for m in ms[ms > 0].tolist():
            tree[m] = self.get_tree(m) if m >= self.vocab_size else m
        return tree

    def max_length(self):
        return max(self.ms_len)

    def get_tree_str(tree, prefix):
        if isinstance(tree, dict):
            base = ""
            last_is_dict = None
            for key, value in tree.items():
                new_prefix = (len(str(key)) + 2) * " " + prefix
                dict_string = MeaningMap.get_tree_str(value, new_prefix)
                if dict_string:
                    base += "\n" + prefix + str(key) + ": " + dict_string
                    last_is_dict = True
                else:
                    base += "\n" + prefix + str(key) + " " if last_is_dict else str(key) + " "
                    last_is_dict = False
            return base
        return None

    def token_frequency(tree, freq):
        if isinstance(tree, dict):
            for key, value in tree.items():
                if key in freq:
                    freq[key] = freq[key] + 1
                else:
                    freq[key] = 1
                MeaningMap.token_frequency(value, freq)
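

# MeaningDataset draws `size` meaning ids uniformly from [start, end), expands each
# through a shared MeaningMap, and keeps only sequences with at least min_seq_len
# tokens. __getitem__ returns a dict with input_ids/labels/token_type_ids plus the
# meaning tree and per-token level/idx metadata.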
class MeaningDataset(Dataset):
    def __init__(
        self,
        start=131072,
        end=1048576,
        size=32768,
        vocab_size=4096,
        max_subitem=10,
        min_seq_len=2,
        seed=42,
        data=None,
        length=None,
        tree=None,
        level=None,
        idx=None,
        use_cache=True,
    ):
        if data is not None and length is not None and tree is not None and level is not None and idx is not None:
            self.data = data
            self.length = length
            self.tree = tree
            self.level = level
            self.idx = idx
            return

        np.random.seed(seed)
        map = MeaningMap(size=end, vocab_size=vocab_size, max_subitem=max_subitem, use_cache=use_cache)
        self.tree = []
        self.data = []
        self.level = []
        self.idx = []
        self.length = []
        meanings = np.random.randint(start, end, size=(size))
        for m in meanings:
            d, l, i = map.get_sequence(m)
            if len(d) >= min_seq_len:
                self.tree.append({m: map.get_tree(m)})
                self.data.append(d)
                self.level.append(l)
                self.idx.append(i)
                self.length.append(len(d))

        unique, counts = np.unique(self.length, return_counts=True)
        print("----------------------------------------------------------------")
        print("MeaningDataset start:" + str(start) + " end:" + str(end) + " space:" + str(end - start))
        print("MeaningDataset size:" + str(len(self.length)))
        print("MeaningDataset max sequence length:" + str(max(unique)))
        print("MeaningDataset most popular sequence length:" + str(unique[np.argmax(counts)]))
        print("----------------------------------------------------------------")

    def __len__(self):
        return len(self.data)

    def len(self):
        return len(self.data)

    def __getitem__(self, idx):
        output = {}
        data = torch.tensor(self.data[idx]).long()
        output["input_ids"] = data
        output["labels"] = data.clone()
        output["token_type_ids"] = torch.zeros(data.shape)
        output["tree"] = self.tree[idx]
        output["level"] = self.level[idx]
        output["idx"] = self.idx[idx]
        return output
    def get_batch(self, idx_list):  # all selected samples must have equal sequence length
        data = [self.data[i] for i in idx_list]
        output = {}
        data = torch.tensor(np.stack(data, axis=0)).long()
        output["input_ids"] = data
        output["labels"] = data.clone()
        output["token_type_ids"] = torch.zeros(data.shape)
        output["tree"] = [self.tree[i] for i in idx_list]
        output["level"] = [self.level[i] for i in idx_list]
        output["idx"] = [self.idx[i] for i in idx_list]
        return output

    def get_token(self, idx):
        return self.data[idx]

    def get_tree(self, idx):
        return self.tree[idx]

    def print_tree(self, idx):
        tokens = self.data[idx]
        tree = self.get_tree(idx)
        s = str(tokens) + "\n"
        s += MeaningMap.get_tree_str(tree, "")
        return s
    def split(self, ratio):
        l = len(self.data)
        middle = int(l * ratio)
        d_shuffle = self.data.copy()
        l_shuffle = self.length.copy()
        m_shuffle = self.tree.copy()
        level_shuffle = self.level.copy()
        i_shuffle = self.idx.copy()
        md1 = MeaningDataset(
            data=d_shuffle[:middle],
            length=l_shuffle[:middle],
            tree=m_shuffle[:middle],
            level=level_shuffle[:middle],
            idx=i_shuffle[:middle],
        )
        md2 = MeaningDataset(
            data=d_shuffle[middle:],
            length=l_shuffle[middle:],
            tree=m_shuffle[middle:],
            level=level_shuffle[middle:],
            idx=i_shuffle[middle:],
        )
        return md1, md2

    def token_frequency(self):
        freq = {}
        for t in self.tree:
            MeaningMap.token_frequency(t, freq)
        return freq
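
    # Boolean mask over a sequence's packed ms_idx values: True where the 4-bit digit
    # at position `level` (counting from the deepest level) equals `index`.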
    def get_seq_mask(idx, level, index):
        assert index < 15, "index must be < 15"
        assert level < 8, "level must be < 8"
        return [((int(i / (16**level)) & 0xF) == index) for i in idx]
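

# BatchGroupMeaningDataloader groups dataset indices by sequence length and pre-builds
# batches of batch_size same-length samples (remainders are dropped), so get_batch can
# stack them without padding. It is itself a Dataset whose __getitem__ returns an
# already-collated batch.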
class BatchGroupMeaningDataloader(Dataset):
    def __init__(self, dataset: MeaningDataset, batch_size, shuffle=True, drop_last=True):
        self.dataset = dataset
        self.batch_size = batch_size
        self.drop_last = drop_last

        length = dataset.length
        unique, counts = np.unique(length, return_counts=True)
        gl = {}
        for u in unique:
            gl[u] = np.where(length == u)[0]

        lens = list(gl.keys())
        gs = {}
        if shuffle:
            for k in gl.keys():
                sl = gl[k].copy()
                np.random.shuffle(sl)
                gs[k] = sl
        else:
            for k in gl.keys():
                sl = gl[k].copy()
                gs[k] = sl

        index = np.zeros((0, batch_size), dtype=np.int64)
        for l in lens:
            batch = len(gs[l]) // batch_size
            new = gs[l][0 : batch * batch_size].reshape(batch, batch_size)
            index = np.concatenate((index, new), axis=0)

        if shuffle:
            index_shuffle = np.arange(0, index.shape[0])
            np.random.shuffle(index_shuffle)
            index = index[index_shuffle]
        self.indexBatch = index
        print("Dataloader batch size:" + str(batch_size) + " count:" + str(len(index)))
        print("Dataloader total:" + str(len(length)) + " drop:" + str(len(length) - len(index) * batch_size))

    def __len__(self):
        return len(self.indexBatch)

    def __getitem__(self, idx):
        return self.dataset.get_batch(self.indexBatch[idx])

    def get_tree(self, idx):
        return [self.dataset.get_tree(i) for i in self.indexBatch[idx]]

    def print_tree(self, idx):
        idx_list = self.indexBatch[idx]
        s = "--------------------------------------------------------\n"
        for i in idx_list:
            s += self.dataset.print_tree(i)
            s += "--------------------------------------------------------\n"
        return s
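

# Quick smoke test when run directly: build a small dataset, inspect one sample's
# tree and sequence mask, then iterate a few batches through the grouped loader and
# a plain torch DataLoader.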
if __name__ == "__main__":
    md = MeaningDataset(100000, 115200, vocab_size=128, size=1024, use_cache=False)
    train, val = md.split(0.95)
    fdaf = md.__getitem__(920)
    print(md.print_tree(920))
    print(md.idx[920])
    fdasfe = MeaningDataset.get_seq_mask(md.idx[920], 1, 1)
    print(fdasfe)
    freq = md.token_frequency()

    dl = BatchGroupMeaningDataloader(train, 2)
    length = len(dl)
    it = iter(dl)
    ne1 = next(it)
    ne2 = next(it)
    ne3 = next(it)
    map1 = dl.get_tree(0)
    map2 = dl.get_tree(1)
    print(dl.print_tree(0))

    dl = DataLoader(
        train,
        num_workers=1,
        persistent_workers=True,
        shuffle=False,
    )
    it = iter(dl)
    ne1 = next(it)
    ne2 = next(it)
    ne3 = next(it)
    for i in range(10):
        daf = next(it)["input_ids"].numpy().tolist()
        print(daf)