# Witllm/wit/meaning_dataset.py
#
# (Source-viewer metadata: 428 lines, 15 KiB, Python — Raw | Normal View | History.)
#
# 2024-03-13 19:41:02 +08:00
import os
import datasets
import torch
import math
import random
from itertools import chain
from typing import Dict, Tuple
from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split
import numpy as np
# 2024-03-18 11:43:41 +08:00
from torch.utils.data import BatchSampler
# 2024-03-13 19:41:02 +08:00
# 2024-04-02 19:59:05 +08:00
class MeaningMap:
    """Randomly generated, hierarchical table of "meanings".

    Meaning ``i < vocab_size`` is the vocabulary token ``i``. Every other meaning
    is composed of up to ``max_subitem`` earlier meanings (sub-item ids are
    smaller than ``i`` whenever ``max_subitem >= 2``). Expanding a composite
    recursively down to tokens gives its token sequence. The flattened tables
    are cached under ``./data/``.

    Attributes (numpy arrays):
        ms_map:    (size, max_subitem) sub-item ids per meaning; -1 marks an empty slot.
        ms_data:   all token sequences concatenated; meaning ``i`` owns
                   ``ms_data[ms_start[i] : ms_start[i] + ms_len[i]]``.
        ms_level:  per token, its depth below the owning meaning (tokens of a
                   vocabulary meaning have level 0, direct children level 1).
        ms_idx:    per token, packed 4-bit child positions (uint32): nibble 0 is the
                   token's position inside its direct parent, nibble 1 the parent's
                   position inside its parent, ... up to nibble ``level - 1``; the
                   remaining high nibbles are 0xF. The packing only has room for 8
                   levels — NOTE(review): deeper tokens are not representable.
        ms_start, ms_len: per meaning, offset and length into the flat arrays.
        ms_height: per meaning, height of its tree (0 for vocabulary tokens).
        ms_weight: per meaning, number of leaves of its tree (1 for vocabulary tokens).
    """

    def __init__(self, size=1048576, vocab_size=4096, max_subitem=10, use_cache=True):
        """Load the tables from the disk cache, or build them and write the cache.

        Args:
            size: total number of meanings, vocabulary tokens included.
            vocab_size: number of leaf meanings (token ids ``0 .. vocab_size-1``).
            max_subitem: maximum number of sub-items of a composite meaning.
            use_cache: load existing cache files when both are present. A rebuild
                always (over)writes the cache files.

        Building draws from numpy's global RNG, so seed ``np.random`` first for a
        reproducible map. NOTE(review): the cache file name does not include a
        seed, so a cache built under one seed is reused under any other seed.
        """
        self.size = size
        self.vocab_size = vocab_size
        self.max_subitem = max_subitem

        path = "./data/"
        file = path + "structured_language_" + str(size) + "_" + str(vocab_size) + "_" + str(max_subitem)
        file_slhwm = file + "_slhwm" + ".npy"  # columns: start, len, height, weight, map...
        file_dli = file + "_dli" + ".npy"  # columns: data, level, idx

        # makedirs(exist_ok=True) tolerates nested paths and a concurrent creator.
        os.makedirs(path, exist_ok=True)

        if os.path.exists(file_slhwm) and os.path.exists(file_dli) and use_cache:
            print("Load from disk cache: " + file)
            slhwm = np.load(file_slhwm)
            dli = np.load(file_dli)
            self.ms_map = slhwm[:, 4:]
            self.ms_data = dli[:, 0]
            self.ms_start = slhwm[:, 0]
            self.ms_len = slhwm[:, 1]
            self.ms_level = dli[:, 1]
            self.ms_idx = dli[:, 2].astype(np.uint32)  # stored as int32, see below
            self.ms_height = slhwm[:, 2]
            self.ms_weight = slhwm[:, 3]
            print("Load end")
            return

        print("Disk cache miss, build new one.")
        table = self._random_map(size, vocab_size, max_subitem)

        # Per-meaning pieces, built bottom-up: a composite's sequence is the
        # concatenation of its sub-items' sequences.
        seq_data = []  # token ids
        seq_level = []  # depth of each token below the meaning
        seq_idx = []  # raw path positions, re-packed by _pack_idx afterwards
        ms_start = []
        ms_len = []
        ms_height = []
        ms_weight = []

        offset = 0
        for token in range(self.vocab_size):
            seq_data.append(np.array([token]))
            seq_level.append(np.array([0]))
            seq_idx.append(np.array([0]))
            ms_start.append(offset)
            ms_len.append(1)
            ms_height.append(0)
            ms_weight.append(1)
            offset += 1

        for meaning in range(self.vocab_size, size):
            row = table[meaning]
            subs = row[row >= 0].tolist()
            assert subs, "map list can not be empty list"
            data = np.concatenate([seq_data[s] for s in subs])
            level = np.concatenate([seq_level[s] + 1 for s in subs])
            # A direct vocabulary child is seeded with 0xFFFFFFF0 + position so that
            # its unused nibbles become 0xF; a composite child shifts each of its
            # tokens' paths one nibble up and appends its own position.
            idx = np.concatenate(
                [
                    ([0xFFFFFFF0 + pos] if s < self.vocab_size else seq_idx[s] * 16 + pos)
                    for pos, s in enumerate(subs)
                ]
            )
            # Token 0 is dropped from composite sequences (get_tree also skips 0).
            keep = data > 0
            data, level, idx = data[keep], level[keep], idx[keep]

            seq_data.append(data)
            seq_level.append(level)
            seq_idx.append(idx)
            ms_start.append(offset)
            ms_len.append(len(data))
            ms_height.append(max(ms_height[s] for s in subs) + 1)
            ms_weight.append(sum(ms_weight[s] for s in subs))
            offset += len(data)

        ms_data = self._flatten(seq_data, np.int32)
        ms_level = self._flatten(seq_level, np.int32)
        ms_idx = self._pack_idx(self._flatten(seq_idx, np.uint32), ms_level)
        ms_start = np.array(ms_start).astype(np.int32)
        ms_height = np.array(ms_height).astype(np.int32)
        ms_weight = np.array(ms_weight).astype(np.int32)
        ms_len = np.array(ms_len).astype(np.int32)

        slhwm = np.concatenate(
            (
                ms_start.reshape((-1, 1)),
                ms_len.reshape((-1, 1)),
                ms_height.reshape((-1, 1)),
                ms_weight.reshape((-1, 1)),
                table,
            ),
            axis=1,
        )
        # idx is stored as int32 (wrapping values >= 2**31); loading casts it back.
        dli = np.stack((ms_data, ms_level, ms_idx.astype(np.int32)), axis=1)
        np.save(file_slhwm, slhwm)
        np.save(file_dli, dli)

        self.ms_map = table  # ms_map[i] = [sub(i), sub(i), ..., -1, -1]
        self.ms_data = ms_data  # tokens of i = ms_data[ms_start[i]:ms_start[i]+ms_len[i]]
        self.ms_start = ms_start
        self.ms_len = ms_len
        self.ms_level = ms_level
        self.ms_idx = ms_idx
        self.ms_height = ms_height
        self.ms_weight = ms_weight
        print("Disk cache build end.")

    @staticmethod
    def _random_map(size, vocab_size, max_subitem):
        """Draw the (size, max_subitem) int32 sub-item table; -1 marks an empty slot."""
        weights = np.random.random((size, max_subitem))
        # Choose how many trailing slots of each row stay empty: sort a copy whose
        # first entry is forced to 0.0 and mask sorted positions above a per-row
        # random threshold. Position 0 is never masked (0.0 > threshold is never
        # true), so every row keeps at least one sub-item.
        sorted_w = weights.copy()
        sorted_w[:, 0] = 0.0
        sorted_w.sort(axis=1)
        threshold = np.random.random((size)).reshape(-1, 1)
        mask_empty = sorted_w > threshold
        # Scale each row so its weights sum to the row index, then truncate: every
        # sub-item id refers to an earlier meaning.
        scale = (np.arange(0, size) / weights.sum(axis=1)).reshape(-1, 1)
        table = (weights * scale).astype(np.int32)
        table[mask_empty] = -1
        # The first vocab_size rows are the vocabulary itself: row i -> [i, -1, ...].
        table[:vocab_size, 0] = np.arange(0, vocab_size)
        table[:vocab_size, 1:] = -1
        return table

    @staticmethod
    def _flatten(parts, dtype):
        """Concatenate a list of 1-D arrays into one ``dtype`` array (empty list -> empty array)."""
        if not parts:
            return np.zeros(0, dtype=dtype)
        return np.concatenate(parts).astype(dtype)

    @staticmethod
    def _pack_idx(ms_idx, ms_level):
        """Re-pack raw path positions into the final ``ms_idx`` layout.

        The raw value holds the top-level position in nibble 0 and the innermost
        position in nibble ``level - 1``. Reverse all 8 nibbles, shift the used
        ones back down (innermost position ends in nibble 0), and fill the unused
        high nibbles with 0xF.
        """
        high_fill = ((np.ones(ms_idx.shape, dtype=np.uint32) * 0xFFFFFFFF) << (ms_level * 4)).astype(np.uint32)
        reversed_idx = (
            ((ms_idx & 0xF) << 28)
            + ((ms_idx & 0xF0) << 20)
            + ((ms_idx & 0xF00) << 12)
            + ((ms_idx & 0xF000) << 4)
            + ((ms_idx & 0xF0000) >> 4)
            + ((ms_idx & 0xF00000) >> 12)
            + ((ms_idx & 0xF000000) >> 20)
            + ((ms_idx & 0xF0000000) >> 28)
        )
        return ((reversed_idx >> ((8 - ms_level) * 4)) + high_fill).astype(np.uint32)

    def get_sequence(self, meaning):
        """Return ``(tokens, levels, idx)`` of ``meaning`` as slices of the flat arrays."""
        begin = self.ms_start[meaning]
        end = begin + self.ms_len[meaning]
        return self.ms_data[begin:end], self.ms_level[begin:end], self.ms_idx[begin:end]

    def get_tree(self, meaning):
        """Return the sub-item tree of ``meaning`` as nested dicts.

        Keys are sub-item ids (empty slots and id 0 are skipped). A composite's
        value is its own tree; a vocabulary token's value is the token id.
        Repeated sub-items collapse into a single key.
        """
        tree = {}
        subs = self.ms_map[meaning]
        for sub in subs[subs > 0].tolist():
            tree[sub] = self.get_tree(sub) if sub >= self.vocab_size else sub
        return tree

    def max_length(self):
        """Return the longest token-sequence length of any meaning."""
        return max(self.ms_len)

    @staticmethod
    def get_tree_str(tree, prefix):
        """Render a ``get_tree`` dict as indented text; return None for a leaf value.

        Subtrees are written as ``key: children`` on their own line. Consecutive
        leaves share a line, and a leaf that follows a subtree starts a new line.
        """
        if not isinstance(tree, dict):
            return None
        base = ""
        last_is_dict = None
        for key, value in tree.items():
            sub_str = MeaningMap.get_tree_str(value, (len(str(key)) + 2) * " " + prefix)
            if sub_str:
                base += "\n" + prefix + str(key) + ": " + sub_str
                last_is_dict = True
            else:
                if last_is_dict:
                    base += "\n" + prefix + str(key) + " "
                else:
                    base += str(key) + " "
                last_is_dict = False
        return base

    @staticmethod
    def token_frequency(tree, freq):
        """Add one to ``freq[key]`` for every key at every depth of ``tree`` (in place)."""
        if isinstance(tree, dict):
            for key, value in tree.items():
                freq[key] = freq.get(key, 0) + 1
                MeaningMap.token_frequency(value, freq)
class MeaningDataset(Dataset):
    """Token sequences of randomly sampled meanings from a ``MeaningMap``.

    Each item is a dict with ``input_ids`` / ``labels`` (the token sequence as a
    long tensor), a zero ``token_type_ids`` tensor, and the meaning's ``tree``,
    per-token ``level`` and packed ``idx`` arrays.
    """

    def __init__(
        self,
        start=131072,
        end=1048576,
        size=32768,
        vocab_size=4096,
        max_subitem=10,
        min_seq_len=2,
        seed=42,
        data=None,
        length=None,
        tree=None,
        level=None,
        idx=None,
        use_cache=True,
    ):
        """Sample ``size`` meanings from ``[start, end)``, or wrap precomputed lists.

        When all of ``data``, ``length``, ``tree``, ``level`` and ``idx`` are given,
        they are stored as-is (this is how ``split`` builds its halves). Otherwise a
        ``MeaningMap(size=end, ...)`` is loaded or built, and every sampled meaning
        whose sequence has at least ``min_seq_len`` tokens is kept. The sample
        depends only on ``seed``, not on whether the map came from the cache.
        """
        # `is not None` (not `!= None`) so numpy-array arguments work too.
        if data is not None and length is not None and tree is not None and level is not None and idx is not None:
            self.data = data
            self.length = length
            self.tree = tree
            self.level = level
            self.idx = idx
            return

        np.random.seed(seed)
        meaning_map = MeaningMap(size=end, vocab_size=vocab_size, max_subitem=max_subitem, use_cache=use_cache)
        # Building the map (cache miss) consumes the global RNG while loading it
        # does not; reseed so the sampled meanings are the same in both cases.
        np.random.seed(seed)
        self.tree = []
        self.data = []
        self.level = []
        self.idx = []
        self.length = []

        meanings = np.random.randint(start, end, size=(size))
        for m in meanings:
            d, l, i = meaning_map.get_sequence(m)
            if len(d) >= min_seq_len:
                self.tree.append({m: meaning_map.get_tree(m)})
                self.data.append(d)
                self.level.append(l)
                self.idx.append(i)
                self.length.append(len(d))

        print("----------------------------------------------------------------")
        print("MeaningDataset start:" + str(start) + " end:" + str(end) + " space:" + str(end - start))
        print("MeaningDataset size:" + str(len(self.length)))
        # Length statistics are undefined for an empty dataset (max of nothing).
        if self.length:
            unique, counts = np.unique(self.length, return_counts=True)
            print("MeaningDataset max sequence length:" + str(max(unique)))
            print("MeaningDataset most popular sequence length:" + str(unique[np.argmax(counts)]))
        print("----------------------------------------------------------------")

    def __len__(self):
        return len(self.data)

    def len(self):
        """Same as ``len(dataset)``."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the sample dict for item ``idx``."""
        output = {}
        data = torch.tensor(self.data[idx]).long()
        output["input_ids"] = data
        output["labels"] = data.clone()
        output["token_type_ids"] = torch.zeros(data.shape)
        output["tree"] = self.tree[idx]
        output["level"] = self.level[idx]
        output["idx"] = self.idx[idx]
        return output

    def get_batch(self, idx_list):
        """Return one stacked batch; all selected sequences must have equal length."""
        data = [self.data[i] for i in idx_list]
        output = {}
        data = torch.tensor(np.stack(data, axis=0)).long()
        output["input_ids"] = data
        output["labels"] = data.clone()
        output["token_type_ids"] = torch.zeros(data.shape)
        output["tree"] = [self.tree[i] for i in idx_list]
        output["level"] = [self.level[i] for i in idx_list]
        output["idx"] = [self.idx[i] for i in idx_list]
        return output

    def get_token(self, idx):
        """Return the raw token array of item ``idx``."""
        return self.data[idx]

    def get_tree(self, idx):
        """Return the ``{meaning: subtree}`` dict of item ``idx``."""
        return self.tree[idx]

    def print_tree(self, idx):
        """Return (not print) the tokens and rendered tree of item ``idx`` as a string."""
        tokens = self.data[idx]
        tree = self.get_tree(idx)
        s = str(tokens) + "\n"
        s += MeaningMap.get_tree_str(tree, "")
        return s

    def split(self, ratio):
        """Split in order (no shuffling): the first ``int(len * ratio)`` items, then the rest."""
        middle = int(len(self.data) * ratio)
        head = MeaningDataset(
            data=self.data[:middle],
            length=self.length[:middle],
            tree=self.tree[:middle],
            level=self.level[:middle],
            idx=self.idx[:middle],
        )
        tail = MeaningDataset(
            data=self.data[middle:],
            length=self.length[middle:],
            tree=self.tree[middle:],
            level=self.level[middle:],
            idx=self.idx[middle:],
        )
        return head, tail

    def token_frequency(self):
        """Count how often each id appears as a key in the items' trees.

        A sub-item repeated under one parent counts once (tree dicts have unique keys).
        """
        freq = {}
        for t in self.tree:
            MeaningMap.token_frequency(t, freq)
        return freq

    @staticmethod
    def get_seq_mask(idx, level, index):
        """Return, per entry of ``idx``, whether its 4-bit nibble ``level`` equals ``index``.

        ``idx`` entries are packed path indices (non-negative ints). 0xF appears to be
        the filler for unused nibbles, hence ``index < 15``.
        """
        assert index < 15, "index must < 15"
        assert level < 8, "level must < 8"
        shift = 4 * level
        # Integer shift instead of float division: exact for any int size.
        return [((int(i) >> shift) & 0xF) == index for i in idx]
class BatchGroupMeaningDataloader(Dataset):
    """Batches a MeaningDataset so that each batch holds sequences of a single length.

    ``self[i]`` returns ``dataset.get_batch(...)`` for the i-th group of dataset
    indices, so the sequences can be stacked without padding.
    """

    def __init__(self, dataset: "MeaningDataset", batch_size, shuffle=True, drop_last=True):
        """Group dataset indices by sequence length and cut each group into batches.

        Args:
            dataset: source dataset; ``dataset.length[i]`` is the length of item i.
            batch_size: number of items per batch.
            shuffle: shuffle indices inside each length group, then the batch order
                (numpy's global RNG).
            drop_last: if True (default), the ``len(group) % batch_size`` leftover
                items of each length group are dropped and ``indexBatch`` is a 2-D
                (batches, batch_size) array. If False, each group's leftovers form
                one extra, smaller batch and ``indexBatch`` is a list of 1-D arrays.
        """
        self.dataset = dataset
        self.batch_size = batch_size
        self.drop_last = drop_last
        # An array (not a list) so `lengths == seq_len` is an element-wise compare.
        lengths = np.asarray(dataset.length)

        batches = []
        for seq_len in np.unique(lengths):
            members = np.where(lengths == seq_len)[0]  # a fresh array, safe to shuffle
            if shuffle:
                np.random.shuffle(members)
            full = len(members) // batch_size * batch_size
            batches.extend(members[:full].reshape(-1, batch_size))
            if not drop_last and full < len(members):
                batches.append(members[full:])

        if shuffle:
            order = np.arange(0, len(batches))
            np.random.shuffle(order)
            batches = [batches[k] for k in order]

        if drop_last:
            self.indexBatch = np.array(batches, dtype=np.int64).reshape(-1, batch_size)
        else:
            self.indexBatch = batches

        kept = sum(len(b) for b in batches)
        print("Dataloader batch size:" + str(batch_size) + " count:" + str(len(self.indexBatch)))
        print("Dataloader total:" + str(len(lengths)) + " drop:" + str(len(lengths) - kept))

    def __len__(self):
        return len(self.indexBatch)

    def __getitem__(self, idx):
        """Return the stacked batch for batch number ``idx``."""
        return self.dataset.get_batch(self.indexBatch[idx])

    def get_tree(self, idx):
        """Return the trees of every item in batch ``idx``."""
        return [self.dataset.get_tree(i) for i in self.indexBatch[idx]]

    def print_tree(self, idx):
        """Return (not print) the rendered trees of batch ``idx``, separated by rules."""
        s = "--------------------------------------------------------\n"
        for i in self.indexBatch[idx]:
            s += self.dataset.print_tree(i)
            s += "--------------------------------------------------------\n"
        return s
if __name__ == "__main__":
    # Small uncached dataset: inspect one sample, its tree, and its nibble mask.
    dataset = MeaningDataset(100000, 115200, vocab_size=128, size=1024, use_cache=False)
    train, val = dataset.split(0.95)
    sample = dataset[920]
    print(dataset.print_tree(920))
    print(dataset.idx[920])
    seq_mask = MeaningDataset.get_seq_mask(dataset.idx[920], 1, 1)
    print(seq_mask)
    freq = dataset.token_frequency()

    # Length-grouped batches of the training split.
    grouped = BatchGroupMeaningDataloader(train, 2)
    batch_count = len(grouped)
    grouped_iter = iter(grouped)
    for _ in range(3):
        next(grouped_iter)
    first_trees = grouped.get_tree(0)
    second_trees = grouped.get_tree(1)
    print(grouped.print_tree(0))

    # A plain torch DataLoader over the same split, one sample per step.
    loader = DataLoader(
        train,
        num_workers=1,
        persistent_workers=True,
        shuffle=False,
    )
    loader_iter = iter(loader)
    for _ in range(3):
        next(loader_iter)
    for _ in range(10):
        print(next(loader_iter)["input_ids"].numpy().tolist())