# Source: Witllm/wit/meaning_dataset.py (507 lines, 17 KiB, Python)

import os
import datasets
import torch
import math
import random
from itertools import chain
from typing import Dict, Tuple
from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split
import numpy as np
from torch.utils.data import BatchSampler
import copy
class MeaningMap:
    """Random hierarchical "meaning" vocabulary, cached on disk.

    Meanings ``0 .. vocab_size - 1`` are leaf tokens. Every meaning ``i`` in
    ``vocab_size .. size - 1`` is composed of 1..``max_subitem`` sub-meanings
    with indices smaller than ``i``; expanding a meaning recursively down to
    leaf tokens gives its token sequence.

    Per-meaning arrays (indexed by meaning id):
        ms_map    -- sub-meanings of each meaning, right-padded with -1
        ms_start  -- offset of the meaning's sequence in the flat arrays
        ms_len    -- length of that sequence
        ms_height -- height of the meaning tree (leaf tokens are 0)
        ms_weight -- number of leaf tokens in the meaning tree

    Flat per-token arrays (all sequences concatenated; meaning ``i`` owns
    ``[ms_start[i], ms_start[i] + ms_len[i])``):
        ms_data     -- token id
        ms_level    -- depth of the token below the owning meaning (0 for a
                       leaf token's own one-token sequence)
        ms_rank_idx -- uint32, one 4-bit nibble per level: nibble k is the
                       position among its siblings of the ancestor k levels
                       above the token (k = 0 is the token itself); nibbles
                       at or above the token's level are 0xF
        ms_rank_all -- same layout, holding the sibling count instead

    NOTE: the packing uses eight 4-bit nibbles, so it only represents
    positions/counts below 16 and depths up to 8.
    """

    def __init__(self, size=1048576, vocab_size=4096, max_subitem=10, use_cache=True):
        self.size = size
        self.vocab_size = vocab_size
        self.max_subitem = max_subitem
        path = "./data/"
        file = "structured_language_" + str(size) + "_" + str(vocab_size) + "_" + str(max_subitem)
        file = path + file + ".npz"
        # NOTE(review): the cache key does not include the RNG seed, so a cached
        # map is reused whatever seed the caller set before construction.
        os.makedirs(path, exist_ok=True)
        if os.path.exists(file) and use_cache:
            self._load_cache(file)
        else:
            self._build(file)

    def _load_cache(self, file):
        """Populate all lookup arrays from a cache file written by ``_build``."""
        print("Load from disk cache: " + file)
        # NpzFile keeps the archive open until closed; the arrays read here
        # are fully loaded, so they stay valid after the ``with`` block.
        with np.load(file) as loaded:
            slhwm = loaded["slhwm"]
            dlra = loaded["dlra"]
        self.ms_map = slhwm[:, 4:]
        self.ms_data = dlra[:, 0]
        self.ms_start = slhwm[:, 0]
        self.ms_len = slhwm[:, 1]
        self.ms_level = dlra[:, 1]
        self.ms_rank_idx = dlra[:, 2].astype(np.uint32)
        self.ms_rank_all = dlra[:, 3].astype(np.uint32)
        self.ms_height = slhwm[:, 2]
        self.ms_weight = slhwm[:, 3]
        print("Load end")

    @staticmethod
    def _random_map(size, vocab_size, max_subitem):
        """Draw the (size, max_subitem) int32 sub-meaning table, padded with -1."""
        row_index = np.arange(0, size)
        raw = np.random.random((size, max_subitem))
        # Row i keeps as many sub-items as it has sorted random values not
        # above a per-row threshold; column 0 is forced to 0, so >= 1 survives.
        mask = raw.copy()
        mask[:, 0] = 0.0
        mask.sort(axis=1)
        thre = np.random.random((size)).reshape(-1, 1).repeat(max_subitem, axis=1)
        mask = mask > thre
        # Scale each row so its entries sum to the row index before truncation,
        # making every sub-item of meaning i refer to an earlier meaning.
        scale = (row_index / raw.sum(axis=1)).reshape(-1, 1).repeat(max_subitem, axis=1)
        sub_map = (raw * scale).astype(np.int32)
        # Rounding can yield exactly i when a row has a single non-zero entry
        # (always the case for max_subitem == 1); a self-reference would make
        # the build loop index a sequence that does not exist yet.
        sub_map = np.minimum(sub_map, (row_index - 1).reshape(-1, 1)).astype(np.int32)
        sub_map[mask] = -1
        # Leaf tokens map to themselves only.
        sub_map[:vocab_size, 0] = np.arange(0, vocab_size)
        sub_map[:vocab_size, 1:] = -1
        return sub_map

    @staticmethod
    def _reverse_nibbles(codes):
        """Reverse the order of the eight 4-bit nibbles of each uint32 in *codes*."""
        return (
            ((codes & 0xF) << 28)
            + ((codes & 0xF0) << 20)
            + ((codes & 0xF00) << 12)
            + ((codes & 0xF000) << 4)
            + ((codes & 0xF0000) >> 4)
            + ((codes & 0xF00000) >> 12)
            + ((codes & 0xF000000) >> 20)
            + ((codes & 0xF0000000) >> 28)
        )

    @staticmethod
    def _align_rank(raw, level):
        """Convert raw per-token rank codes (uint32) to the final nibble layout.

        Raw codes hold the owning meaning's level in nibble 0 and deeper
        levels above it. They are nibble-reversed and shifted so that a token
        at depth L keeps L nibbles with nibble 0 = its own position. Nibbles
        L..7 are then filled with 0xF.
        """
        fill = ((np.ones(raw.shape, dtype=np.uint32) * 0xFFFFFFFF) << (level * 4)).astype(np.uint32)
        return ((MeaningMap._reverse_nibbles(raw) >> ((8 - level) * 4)) + fill).astype(np.uint32)

    def _build(self, file):
        """Generate a new map, save it to *file*, and populate all lookup arrays."""
        print("Disk cache miss, build new one.")
        vocab_size = self.vocab_size
        sub_map = self._random_map(self.size, vocab_size, self.max_subitem)

        ms_data = []  # token sequence of each meaning
        ms_level = []  # per-token depth below the meaning; leaf tokens are level 0
        ms_rank_idx = []  # per-token raw position code
        ms_rank_all = []  # per-token raw sibling-count code
        ms_start = []  # sequence offset in the flat arrays
        ms_len = []  # sequence length
        ms_height = []  # meaning tree height
        ms_weight = []  # number of leaf tokens in the meaning tree
        offset = 0
        for token in range(vocab_size):
            ms_data.append(np.array([token]))
            ms_level.append(np.array([0]))
            ms_rank_idx.append(np.array([0]))
            ms_rank_all.append(np.array([0]))
            ms_start.append(offset)
            ms_len.append(1)
            ms_height.append(0)
            ms_weight.append(1)
            offset += 1
        for meaning in range(vocab_size, self.size):
            row = sub_map[meaning]
            subs = row[row >= 0].tolist()  # only -1 is padding; token 0 is a real sub-item
            assert subs, "map list can not be empty list"
            n_sub = len(subs)
            seq = np.concatenate([ms_data[s] for s in subs])
            level = np.concatenate([ms_level[s] + 1 for s in subs])
            # Raw codes put this level's value in the lowest nibble and push a
            # composite child's code up one nibble; a leaf child starts from
            # 0xFFFFFFF0 so its unused upper nibbles read 0xF.
            rank_idx = np.concatenate(
                [
                    [0xFFFFFFF0 + pos] if s < vocab_size else ms_rank_idx[s] * 16 + pos
                    for pos, s in enumerate(subs)
                ]
            )
            rank_all = np.concatenate(
                [[0xFFFFFFF0 + n_sub] if s < vocab_size else ms_rank_all[s] * 16 + n_sub for s in subs]
            )
            ms_data.append(seq)
            ms_level.append(level)
            ms_rank_idx.append(rank_idx)
            ms_rank_all.append(rank_all)
            ms_start.append(offset)
            ms_len.append(len(seq))
            ms_height.append(max(ms_height[s] for s in subs) + 1)
            ms_weight.append(sum(ms_weight[s] for s in subs))
            offset += len(seq)

        ms_data = np.concatenate(ms_data).astype(np.int32)
        ms_level = np.concatenate(ms_level).astype(np.int32)
        ms_rank_idx = self._align_rank(np.concatenate(ms_rank_idx).astype(np.uint32), ms_level)
        ms_rank_all = self._align_rank(np.concatenate(ms_rank_all).astype(np.uint32), ms_level)
        ms_start = np.array(ms_start).astype(np.int32)
        ms_len = np.array(ms_len).astype(np.int32)
        ms_height = np.array(ms_height).astype(np.int32)
        ms_weight = np.array(ms_weight).astype(np.int32)

        # Cache layout: slhwm = [start, len, height, weight, map...] per meaning,
        # dlra = [data, level, rank_idx, rank_all] per token (rank codes stored
        # bit-for-bit as int32).
        slhwm = np.concatenate(
            (
                ms_start.reshape((-1, 1)),
                ms_len.reshape((-1, 1)),
                ms_height.reshape((-1, 1)),
                ms_weight.reshape((-1, 1)),
                sub_map,
            ),
            axis=1,
        )
        dlra = np.stack((ms_data, ms_level, ms_rank_idx.astype(np.int32), ms_rank_all.astype(np.int32)), axis=1)
        np.savez(file, slhwm=slhwm, dlra=dlra)

        self.ms_data = ms_data  # map[i] = ms_data[ms_start[i]:ms_start[i]+ms_len[i]]
        self.ms_level = ms_level
        self.ms_rank_idx = ms_rank_idx
        self.ms_rank_all = ms_rank_all
        self.ms_map = sub_map  # ms_map[i] = [sub(i), sub(i), ..., -1, -1]
        self.ms_start = ms_start
        self.ms_len = ms_len
        self.ms_height = ms_height
        self.ms_weight = ms_weight
        print("Disk cache build end.")

    def get_sequence(self, meaning):
        """Return ``(tokens, levels, rank_idx, rank_all)`` views for *meaning*."""
        start = self.ms_start[meaning]
        stop = start + self.ms_len[meaning]
        return (
            self.ms_data[start:stop],
            self.ms_level[start:stop],
            self.ms_rank_idx[start:stop],
            self.ms_rank_all[start:stop],
        )

    def get_tree(self, meaning):
        """Return the nested sub-meaning tree of *meaning*.

        Keys are sub-meaning ids. Values are nested dicts for composite
        sub-meanings and the token id itself for leaf tokens.
        NOTE: a sub-meaning that appears several times collapses into one key.
        """
        tree = {}
        row = self.ms_map[meaning]
        # -1 is padding; token 0 is valid (it also appears in the sequence).
        for sub in row[row >= 0].tolist():
            tree[sub] = self.get_tree(sub) if sub >= self.vocab_size else sub
        return tree

    def max_length(self):
        """Return the longest sequence length over all meanings."""
        return self.ms_len.max()

    @staticmethod
    def get_tree_str(tree, prefix):
        """Render a ``get_tree`` dict as indented text; return ``None`` for a leaf."""
        if not isinstance(tree, dict):
            return None
        base = ""
        last_is_dict = None
        for key, value in tree.items():
            child_prefix = " " * (len(str(key)) + 2) + prefix
            dict_string = MeaningMap.get_tree_str(value, child_prefix)
            if dict_string:
                base += "\n" + prefix + str(key) + ": " + dict_string
                last_is_dict = True
            else:
                # A leaf right after a sub-tree starts a new line; consecutive
                # leaves share a line.
                base += ("\n" + prefix + str(key) + " ") if last_is_dict else (str(key) + " ")
                last_is_dict = False
        return base

    @staticmethod
    def token_frequency(tree, freq):
        """Count every key of *tree* (internal meanings and leaves) into *freq* in place."""
        if isinstance(tree, dict):
            for key, value in tree.items():
                freq[key] = freq.get(key, 0) + 1
                MeaningMap.token_frequency(value, freq)
class MeaningDataset(Dataset):
    """Token sequences of randomly drawn meanings from a ``MeaningMap``.

    ``size`` meanings are drawn uniformly from ``[start, end)`` after seeding
    NumPy's global RNG with ``seed``. Meanings whose sequence is shorter than
    ``min_seq_len`` are skipped, so ``len(self)`` can be smaller than ``size``.
    For each kept item the dataset stores the token sequence, per-token levels
    and rank codes, the meaning id, and ``{meaning: sub-meaning tree}``.
    """

    def __init__(
        self,
        start,
        end,
        size,
        vocab_size,
        max_subitem=10,
        min_seq_len=2,
        seed=42,
        use_cache=True,
    ):
        np.random.seed(seed)
        meaning_map = MeaningMap(size=end, vocab_size=vocab_size, max_subitem=max_subitem, use_cache=use_cache)
        # Re-seed so the drawn meanings do not depend on whether the map was
        # built (which consumes random numbers) or loaded from the cache.
        np.random.seed(seed)
        self.tree = []  # {meaning: sub-meaning tree} per item
        self.seq = []  # token ids per item
        self.level = []  # per-token levels per item
        self.rank_idx = []  # per-token position codes per item
        self.rank_all = []  # per-token sibling-count codes per item
        self.seq_meaning = []  # meaning id per item
        self.m_height = meaning_map.ms_height
        self.m_weight = meaning_map.ms_weight
        meanings = np.random.randint(start, end, size=(size))
        seq_len = []
        for m in meanings:
            data, level, rank_idx, rank_all = meaning_map.get_sequence(m)
            if len(data) < min_seq_len:
                continue
            self.tree.append({m: meaning_map.get_tree(m)})
            self.seq.append(data)
            self.level.append(level)
            self.rank_idx.append(rank_idx)
            self.rank_all.append(rank_all)
            self.seq_meaning.append(m)
            seq_len.append(len(data))
        print("----------------------------------------------------------------")
        print("MeaningDataset start:" + str(start) + " end:" + str(end) + " space:" + str(end - start))
        print("MeaningDataset size:" + str(len(seq_len)))
        if seq_len:  # max()/argmax raise on an empty selection
            unique, counts = np.unique(seq_len, return_counts=True)
            print("MeaningDataset max sequence length:" + str(unique.max()))
            print("MeaningDataset most popular sequence length:" + str(unique[np.argmax(counts)]))
        print("----------------------------------------------------------------")

    def __len__(self):
        return len(self.seq)

    def len(self):
        """Return the number of items (same as ``len(self)``)."""
        return len(self.seq)

    def __getitem__(self, idx):
        """Return one item as a dict of model inputs plus its tree and levels.

        NOTE(review): ``token_type_ids`` uses torch's default floating dtype;
        embedding lookups need integer ids, so confirm consumers ignore it.
        """
        output = {}
        data = torch.tensor(self.seq[idx]).long()
        output["input_ids"] = data
        output["labels"] = data.clone()
        output["token_type_ids"] = torch.zeros(data.shape)
        output["tree"] = self.tree[idx]
        output["level"] = self.level[idx]
        return output

    def get_batch(self, idx_list):  # must equal sequence length
        """Stack the items in *idx_list*; all of them must have the same length."""
        output = {}
        data = torch.tensor(np.stack([self.seq[i] for i in idx_list], axis=0)).long()
        output["input_ids"] = data
        output["labels"] = data.clone()
        output["token_type_ids"] = torch.zeros(data.shape)
        output["tree"] = [self.tree[i] for i in idx_list]
        output["level"] = [self.level[i] for i in idx_list]
        return output

    def get_token(self, idx):  # must equal sequence length
        return self.seq[idx]

    def get_tree(self, idx):
        return self.tree[idx]

    def print_tree(self, idx):
        """Return the token ids of item *idx* followed by its rendered tree."""
        tokens = self.seq[idx]
        tree = self.get_tree(idx)
        s = str(tokens) + "\n"
        s += MeaningMap.get_tree_str(tree, "")
        return s

    def copy(self, start, end):
        """Return an independent dataset holding items ``[start, end)``.

        Only the kept slice of each per-item list is deep-copied, instead of
        deep-copying the whole dataset and discarding most of it. One memo
        dict is shared, so aliasing between attributes is preserved as with
        ``copy.deepcopy(self)``.
        """
        sliced = {"tree", "seq", "level", "rank_idx", "rank_all", "seq_meaning"}
        new = copy.copy(self)
        memo = {id(self): new}
        for name, value in vars(self).items():
            if name in sliced:
                value = value[start:end]
            setattr(new, name, copy.deepcopy(value, memo))
        return new

    def split(self, ratio):
        """Split into (first ``ratio`` fraction, remainder) without shuffling."""
        total = self.len()
        middle = int(total * ratio)
        return self.copy(0, middle), self.copy(middle, total)

    def token_frequency(self):
        """Count how often each meaning/token appears as a key across all trees."""
        freq = {}
        for t in self.tree:
            MeaningMap.token_frequency(t, freq)
        return freq

    def get_seq_mask(self, idx, level, index):
        """Return a boolean mask over the tokens of item *idx*.

        Compares nibble *level* of each token's rank codes. The mask is True
        where ``rank_idx`` equals *index*, or, for a negative *index*, where it
        equals ``rank_all + index``. Since ``rank_all`` holds sibling counts,
        ``-1`` selects the last sibling.
        """
        assert index < 15, "index must < 15"
        assert level < 8, "level must < 8"
        shift = 4 * level
        rank_idx = (self.rank_idx[idx] >> shift).astype(np.int32) & 0xF
        rank_all = (self.rank_all[idx] >> shift).astype(np.int32) & 0xF
        target = rank_all + index if index < 0 else index
        return rank_idx == target
class BatchGroupMeaningDataloader(Dataset):
    """Pre-computed batches over a MeaningDataset; each batch has one sequence length.

    Sequences are grouped by length so ``dataset.get_batch`` can stack them.
    With ``shuffle``, group members and batch order are shuffled using NumPy's
    global RNG. With ``drop_last`` (the default), sequences that do not fill a
    whole batch within their length group are dropped. Otherwise they form a
    final, shorter batch for that group.

    ``indexBatch`` is an int64 array of shape ``(n_batches, batch_size)`` when
    every batch is full, and a list of 1-D index arrays otherwise.
    """

    def __init__(self, dataset: "MeaningDataset", batch_size, shuffle=True, drop_last=True):
        self.dataset = dataset
        self.batch_size = batch_size
        self.drop_last = drop_last
        seq_len = np.array([len(s) for s in dataset.seq], dtype=np.int64)

        batches = []
        # np.unique returns sorted lengths; groups are visited (and shuffled)
        # in that order, so a given seed always yields the same batches.
        for length in np.unique(seq_len):
            members = np.where(seq_len == length)[0]
            if shuffle:
                np.random.shuffle(members)
            n_full = len(members) // batch_size
            batches.extend(members[: n_full * batch_size].reshape(n_full, batch_size))
            if not drop_last and n_full * batch_size < len(members):
                batches.append(members[n_full * batch_size :])

        if shuffle:
            order = np.arange(0, len(batches))
            np.random.shuffle(order)
            batches = [batches[k] for k in order]

        if all(len(b) == batch_size for b in batches):
            self.indexBatch = np.array(batches, dtype=np.int64).reshape(-1, batch_size)
        else:
            self.indexBatch = batches  # ragged: some groups end with a short batch
        used = sum(len(b) for b in batches)
        print("Dataloader batch size:" + str(batch_size) + " count:" + str(len(self.indexBatch)))
        print("Dataloader total:" + str(len(seq_len)) + " drop:" + str(len(seq_len) - used))

    def __len__(self):
        return len(self.indexBatch)

    def __getitem__(self, idx):
        return self.dataset.get_batch(self.indexBatch[idx])

    def get_tree(self, idx):
        """Return the trees of every item in batch *idx*."""
        return [self.dataset.get_tree(i) for i in self.indexBatch[idx]]

    def print_tree(self, idx):
        """Render every item of batch *idx*, framed by separator lines."""
        separator = "--------------------------------------------------------\n"
        s = separator
        for i in self.indexBatch[idx]:
            s += self.dataset.print_tree(i)
        s += separator
        return s
if __name__ == "__main__":
    # Smoke test: build a small dataset, inspect item 920 and check one of its masks.
    md = MeaningDataset(100000, 115200, vocab_size=1024, size=1024, use_cache=False)
    train, val = md.split(0.95)
    sample = md[920]
    print(md.print_tree(920))
    print(md.rank_idx[920])
    print(md.rank_all[920])
    for mask_level, mask_index in ((0, -1), (1, 0), (1, -1), (1, 1)):
        mask = md.get_seq_mask(920, mask_level, mask_index)
        print(mask)
    # Expected first 57 entries of the last mask (level 1, index 1), written as
    # (value, run length) pairs; presumably specific to the default seed and
    # the map parameters above.
    expected_runs = (
        (False, 11),
        (True, 9),
        (False, 6),
        (True, 1),
        (False, 7),
        (True, 4),
        (False, 4),
        (True, 9),
        (False, 6),
    )
    expected = np.array([value for value, count in expected_runs for _ in range(count)])
    assert all(np.equal(mask[0:57], expected)), "False"
    freq = md.token_frequency()
    dl = BatchGroupMeaningDataloader(train, 2)