# Witllm/wit/meaning_dataset.py

import os
import copy
import time

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset

# import warnings
# warnings.filterwarnings("ignore", ".*does not have many workers.*")

class MeaningMap:
    def __init__(self, size=1048576, vocab_size=4096, max_subitem=10, min_subitem=1, use_cache=True):
        assert size > 0 and vocab_size > 0 and max_subitem > 0 and min_subitem > 0, "Invalid input"
        assert min_subitem <= max_subitem, "Invalid input"
        self.size = size
        self.vocab_size = vocab_size
        self.max_subitem = max_subitem
        self.min_subitem = min_subitem
        datastep = 0x8000000
        path = "./data/"
        file = "structured_language_" + str(size) + "_" + str(vocab_size)
        file += "_" + str(max_subitem) + "_" + str(min_subitem)
        file_prop = path + file + "_prop.npy"
        file_data = path + file + "_data.npy"
        file_level = path + file + "_level.npy"
        file_rank_idx = path + file + "_rank_idx.npy"
        file_rank_all = path + file + "_rank_all.npy"

        start_time = time.time()
        if not os.path.exists(path):
            os.mkdir(path)
        if (
            os.path.exists(file_prop)
            and os.path.exists(file_data)
            and os.path.exists(file_level)
            and os.path.exists(file_rank_idx)
            and os.path.exists(file_rank_all)
            and use_cache
        ):
            print("Load from disk cache: " + file)
            slhwm = np.load(file_prop)
            self.ms_map = slhwm[:, 4:]
            self.ms_data = np.load(file_data)
            self.ms_start = slhwm[:, 0]
            self.ms_len = slhwm[:, 1]
            self.ms_level = np.load(file_level)
            self.ms_rank_idx = np.load(file_rank_idx)
            self.ms_rank_all = np.load(file_rank_all)
            self.ms_height = slhwm[:, 2]
            self.ms_weight = slhwm[:, 3]
            print("Load end, elapsed:" + str(time.time() - start_time) + "s")
        else:
            print("Disk cache miss, build new one. size:" + str(size))

            # Draw random sub-item weights; the sorted mask guarantees that at
            # least min_subitem and at most max_subitem entries per row survive.
            index = np.arange(0, size)
            map = np.random.random((size, max_subitem))
            mask_zero = map.copy()
            mask_zero[:, 0:min_subitem] = 0.0
            mask_zero.sort(axis=1)
            thre = np.random.random((size)).reshape(-1, 1).repeat(max_subitem, axis=1)
            mask_zero = mask_zero > thre

            # Scale each row so every sub-meaning id stays below its own meaning
            # id, keeping the composition graph acyclic; -1 marks unused slots.
            item_sum = map.sum(axis=1)
            scale = (index / item_sum).reshape(-1, 1).repeat(max_subitem, axis=1)
            map = (map * scale).astype(np.int32)
            map[mask_zero] = -1

            # The first vocab_size meanings are the tokens themselves.
            map[:vocab_size, 0] = np.arange(0, vocab_size)
            map[:vocab_size, 1:] = -1

            ms_start = np.zeros((size), dtype=np.int32)  # meaning sequence start
            ms_end = np.zeros((size), dtype=np.int32)  # meaning sequence end
            ms_len = np.zeros((size), dtype=np.int32)  # meaning sequence len
            ms_height = np.zeros((size), dtype=np.int32)  # meaning tree height
            ms_weight = np.zeros((size), dtype=np.int32)  # meaning tree weight
            ms_data = np.zeros((datastep), dtype=np.int32)  # meaning sequence
            ms_level = np.zeros((datastep), dtype=np.uint32)  # meaning level, vocab's level is 0
            ms_rank_idx = np.zeros((datastep), dtype=np.uint32)  # meaning index at every level
            ms_rank_all = np.zeros((datastep), dtype=np.uint32)  # sibling count at every level

            index = 0
            for i in range(self.vocab_size):
                ms_data[i] = i
                ms_start[i] = index
                ms_end[i] = index + 1
                ms_len[i] = 1
                ms_level[i] = 0
                ms_rank_idx[i] = 0xFFFFFFF
                ms_rank_all[i] = 0xFFFFFFF
                ms_height[i] = 0
                ms_weight[i] = 1
                index = index + 1
            for i in range(self.vocab_size, size):
                m = map[i]
                m = m[m >= 0]  # do not cut off entries such as [0]
                m_len = len(m)
                m_list = m.tolist()
                assert m_list, "map list can not be an empty list"
                # Concatenate the flattened sequences of all sub-meanings; idxidx
                # records which sub-item each copied token came from.
                idx = np.concatenate([np.arange(ms_start[m], ms_end[m]) for m in m_list])
                idxidx = np.concatenate(
                    [np.ones(n, dtype=np.uint32) * j for j, n in enumerate(ms_end[m_list] - ms_start[m_list])]
                )
                len_ma = len(idx)
                end = index + len_ma
                if ms_data.size < end:  # grow the flat arrays by one datastep chunk
                    ms_data = np.concatenate([ms_data, np.zeros((datastep), dtype=np.int32)])
                    ms_level = np.concatenate([ms_level, np.zeros((datastep), dtype=np.uint32)])
                    ms_rank_idx = np.concatenate([ms_rank_idx, np.zeros((datastep), dtype=np.uint32)])
                    ms_rank_all = np.concatenate([ms_rank_all, np.zeros((datastep), dtype=np.uint32)])
                # Shift each token's packed per-level path left by one nibble and
                # append the position (rank_idx) and sibling count (rank_all) of
                # the current composition step.
                ms_data[index:end] = ms_data[idx]
                ms_level[index:end] = ms_level[idx] + 1
                ms_rank_idx[index:end] = (ms_rank_idx[idx] * 16 + idxidx).astype(np.uint32)
                ms_rank_all[index:end] = (ms_rank_all[idx] * 16 + m_len).astype(np.uint32)
                ms_start[i] = index
                ms_end[i] = end
                ms_len[i] = len_ma
                ms_height[i] = max(ms_height[m_list]) + 1
                ms_weight[i] = sum(ms_weight[m_list])
                index = index + len_ma
                if i % 10000 == 0:
                    print(i)

            print("Mapping end, elapsed:" + str(time.time() - start_time) + "s")
            np.save(file_data, ms_data)
            np.save(file_level, ms_level)
            np.save(file_rank_idx, ms_rank_idx)
            np.save(file_rank_all, ms_rank_all)
            # Pack start/len/height/weight plus the map into one array so a
            # single file round-trips every per-meaning property.
            slhwm = np.concatenate(
                (
                    ms_start.reshape((-1, 1)),
                    ms_len.reshape((-1, 1)),
                    ms_height.reshape((-1, 1)),
                    ms_weight.reshape((-1, 1)),
                    map,
                ),
                axis=1,
            )
            np.save(file_prop, slhwm)
            self.ms_data = ms_data  # sequence[i] = ms_data[ms_start[i] : ms_start[i] + ms_len[i]]
            self.ms_level = ms_level
            self.ms_rank_idx = ms_rank_idx
            self.ms_rank_all = ms_rank_all
            self.ms_map = map  # ms_map[i] = [sub(i), sub(i), ..., sub(i)]
            self.ms_start = ms_start
            self.ms_len = ms_len
            self.ms_height = ms_height
            self.ms_weight = ms_weight
            print("Disk cache build end, elapsed:" + str(time.time() - start_time) + "s")
    def get_sequence(self, meaning):  # return the flattened token sequence of a meaning
        start = self.ms_start[meaning]
        length = self.ms_len[meaning]
        return (
            self.ms_data[start : start + length],
            self.ms_level[start : start + length],
            self.ms_rank_idx[start : start + length],
            self.ms_rank_all[start : start + length],
        )
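
    # For a plain vocab token t (t < vocab_size), get_sequence(t) returns
    # ([t], [0], [0xFFFFFFF], [0xFFFFFFF]); for a composed meaning it returns
    # the concatenation of its sub-items' sequences with every level raised by
    # one, so all four arrays always share one length.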
    def get_tree(self, meaning):  # return all sub-items of a meaning as a nested dict
        tree = {}
        ms = self.ms_map[meaning]
        for m in ms[ms >= 0].tolist():  # >= 0: token 0 is a valid sub-item
            tree[m] = self.get_tree(m) if m >= self.vocab_size else m
        return tree

    def max_length(self):
        return max(self.ms_len)
    @staticmethod
    def get_tree_str(tree, prefix):
        if isinstance(tree, list):
            base = ""
            for t in tree:
                base += MeaningMap.get_tree_str(t, "")
            return base
        if isinstance(tree, dict):
            base = ""
            last_is_dict = None
            for key, value in tree.items():
                new_prefix = (len(str(key)) + 2) * " " + prefix
                dict_string = MeaningMap.get_tree_str(value, new_prefix)
                if dict_string:
                    base += "\n" + prefix + str(key) + ": " + dict_string
                    last_is_dict = True
                else:
                    base += "\n" + prefix + str(key) + " " if last_is_dict else str(key) + " "
                    last_is_dict = False
            return base
        return None
    @staticmethod
    def get_tree_indexed_str(tree, data, prefix):
        if isinstance(tree, list):
            base = ""
            qlen = 0
            for i, t in enumerate(tree):
                s, l = MeaningMap.get_tree_indexed_str(t, data[i], "")
                base += s
                qlen += l
            return (base, qlen)
        if isinstance(tree, dict):
            base = ""
            qlen = 0
            last_is_dict = None
            for key, value in tree.items():
                new_prefix = (len(str(key)) + 2) * " " + prefix
                dict_string, l = MeaningMap.get_tree_indexed_str(value, data[qlen:], new_prefix)
                if dict_string:
                    base += "\n" + prefix + str(key) + ": " + dict_string
                    last_is_dict = True
                else:
                    base += "\n" + prefix + str(data[qlen]) + " " if last_is_dict else str(data[qlen]) + " "
                    last_is_dict = False
                qlen += l
            return (base, qlen)
        return (None, 1)
    @staticmethod
    def token_frequency(tree, freq):
        if isinstance(tree, dict):
            for key, value in tree.items():
                freq[key] = freq.get(key, 0) + 1
                MeaningMap.token_frequency(value, freq)
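
# A minimal usage sketch (the helper name and sizes below are illustrative,
# not part of the module): build a small map without the disk cache and walk
# one composed meaning.
def _demo_meaning_map():
    mm = MeaningMap(size=2048, vocab_size=256, max_subitem=4, min_subitem=1, use_cache=False)
    seq, level, rank_idx, rank_all = mm.get_sequence(1024)  # any id >= vocab_size works
    print("sequence:", seq.tolist())
    print("levels:", level.tolist())
    print(MeaningMap.get_tree_str(mm.get_tree(1024), ""))
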
class MeaningDataset(Dataset):
    def __init__(
        self,
        start,
        end,
        vocab_size,
        size=None,
        max_subitem=10,
        min_subitem=1,
        min_seq_len=2,
        seed=42,
        use_cache=True,
    ):
        np.random.seed(seed)
        mm = MeaningMap(end, vocab_size, max_subitem, min_subitem, use_cache=use_cache)
        np.random.seed(seed)  # reseed: building the map consumed the random stream
        self.mask_level = None
        self.mask_idx = None
        self.tree = []
        self.seq = []
        self.level = []
        self.rank_idx = []
        self.rank_all = []
        self.seq_meaning = []
        self.m_height = mm.ms_height
        self.m_weight = mm.ms_weight
        if size:
            meanings = np.random.randint(start, end, size=(size))
        else:
            meanings = np.arange(start, end)
        seq_len = []
        def reverse_nibbles(x):
            # Reverse the eight 4-bit nibbles of each uint32, vectorized.
            r = (x & 0xF) << 28
            r = r + ((x & 0xF0) << 20)
            r = r + ((x & 0xF00) << 12)
            r = r + ((x & 0xF000) << 4)
            r = r + ((x & 0xF0000) >> 4)
            r = r + ((x & 0xF00000) >> 12)
            r = r + ((x & 0xF000000) >> 20)
            r = r + ((x & 0xF0000000) >> 28)
            return r

        for m in meanings:
            d, l, i, a = mm.get_sequence(m)
            if len(d) < min_seq_len:
                continue
            self.tree.append({m: mm.get_tree(m)})
            self.seq.append(d)
            self.level.append(l)
            self.seq_meaning.append(m)
            seq_len.append(len(d))
            # ms_rank_idx packs, 4 bits per level, each token's position at
            # every split above it, with the root split in the lowest nibble.
            # Reverse the nibbles and right-align so that level 0 (the split
            # closest to the token) lands in the lowest nibble; dm then fills
            # the unused high nibbles with 0xF, matching the 0xFFFFFFF sentinel
            # used for plain vocab tokens. ms_rank_all gets the same treatment
            # with sibling counts instead of positions.
            dm = np.ones(i.shape, dtype=np.uint32)
            dm = ((dm * 0xFFFFFFFF) << (l * 4)).astype(np.uint32)
            shift = (8 - l) * 4
            rank_idx = ((reverse_nibbles(i) >> shift) + dm).astype(np.uint32)
            rank_all = ((reverse_nibbles(a) >> shift) + dm).astype(np.uint32)
            self.rank_idx.append(rank_idx)
            self.rank_all.append(rank_all)
        unique, counts = np.unique(seq_len, return_counts=True)
        print("----------------------------------------------------------------")
        print("MeaningDataset start:" + str(start) + " end:" + str(end) + " space:" + str(end - start))
        print("MeaningDataset size:" + str(len(seq_len)))
        print("MeaningDataset max sequence length:" + str(max(unique)))
        print("MeaningDataset most popular sequence length:" + str(unique[np.argmax(counts)]))
        print("----------------------------------------------------------------")
    def __len__(self):
        return len(self.seq)

    def len(self):
        return len(self.seq)

    def set_mask(self, level=None, idx=None):
        if level is not None and idx is not None:  # validate the new mask before storing it
            assert isinstance(level, list), "mask level must be list"
            assert isinstance(idx, list), "mask index must be list"
            assert len(level) > 0, "len must > 0"
            assert len(level) == len(idx), "mask level and mask index must be same length"
        self.mask_level = level
        self.mask_idx = idx
    def __getitem__(self, idx):
        return self.get_batch([idx])

    def get_batch(self, idx_list):  # all selected sequences must share one length to stack
        data = [self.seq[i] for i in idx_list]
        output = {}
        data = torch.tensor(np.stack(data, axis=0)).long()
        output["input_ids"] = data
        output["labels"] = data.clone()
        output["token_type_ids"] = torch.zeros(data.shape)
        output["tree"] = [self.tree[i] for i in idx_list]
        output["mask"] = self.get_seq_mask_tensor(idx_list)
        return output

    def get_token(self, idx):
        return self.seq[idx]

    def get_tree(self, idx):
        return self.tree[idx]

    def print_tree(self, idx):
        tokens = self.seq[idx]
        tree = self.get_tree(idx)
        s = str(tokens) + "\n"
        s += MeaningMap.get_tree_str(tree, "")
        return s
    def copy(self, start, end):
        new = copy.deepcopy(self)
        new.tree = new.tree[start:end]
        new.seq = new.seq[start:end]
        new.level = new.level[start:end]
        new.rank_idx = new.rank_idx[start:end]
        new.rank_all = new.rank_all[start:end]
        new.seq_meaning = new.seq_meaning[start:end]
        new.mask_level = self.mask_level
        new.mask_idx = self.mask_idx
        return new

    def split(self, ratio):
        l = self.len()
        middle = int(l * ratio)
        return self.copy(0, middle), self.copy(middle, l)

    def token_frequency(self):
        freq = {}
        for t in self.tree:
            MeaningMap.token_frequency(t, freq)
        return freq
    def get_seq_mask(self, idx, level, index):
        # level counts from the split closest to the token; index selects a
        # child position, negative values counting back from the sibling count.
        assert index < 15, "index must < 15"
        assert level < 8, "level must < 8"
        rank_idx = (self.rank_idx[idx] >> (4 * level)).astype(np.int32) & 0xF
        rank_all = (self.rank_all[idx] >> (4 * level)).astype(np.int32) & 0xF
        return rank_idx == (rank_all + index if index < 0 else index)

    def get_seq_mask_tensor(self, idx_list):
        if self.mask_level is not None and self.mask_idx is not None:
            mask = torch.tensor(
                np.stack([self.get_seq_mask(idx, self.mask_level[0], self.mask_idx[0]) for idx in idx_list], axis=0)
            )
            for i, l in enumerate(self.mask_level[1:]):
                mask = mask & torch.tensor(
                    np.stack([self.get_seq_mask(idx, l, self.mask_idx[i + 1]) for idx in idx_list], axis=0)
                )
            return mask
        else:
            return None
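
# A minimal usage sketch (helper name and numbers are illustrative): mark every
# token that is the last child of its innermost split, then fetch a one-sequence
# batch; get_batch only stacks sequences of equal length.
def _demo_meaning_dataset():
    ds = MeaningDataset(4096, 8192, vocab_size=1024, size=256, use_cache=False)
    ds.set_mask(level=[0], idx=[-1])
    batch = ds.get_batch([0])
    print(batch["input_ids"].shape, batch["mask"].shape)
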
class BatchGroupMeaningDataloader(Dataset):
    def __init__(self, dataset: MeaningDataset, batch_size, shuffle=True, drop_last=True):
        self.dataset = dataset
        self.batch_size = batch_size
        self.drop_last = drop_last
        # Group sequence indices by length so every batch stacks cleanly;
        # incomplete length groups are always dropped, regardless of drop_last.
        seq_len = np.array([len(s) for s in dataset.seq])
        unique, counts = np.unique(seq_len, return_counts=True)
        gl = {}
        for u in unique:
            gl[u] = np.where(seq_len == u)[0]
        lens = list(gl.keys())
        gs = {}
        for k in gl.keys():
            sl = gl[k].copy()
            if shuffle:
                np.random.shuffle(sl)
            gs[k] = sl
        index = np.zeros((0, batch_size), dtype=np.int64)
        for l in lens:
            batch = len(gs[l]) // batch_size
            new = gs[l][0 : batch * batch_size].reshape(batch, batch_size)
            index = np.concatenate((index, new), axis=0)
        if shuffle:
            index_shuffle = np.arange(0, index.shape[0])
            np.random.shuffle(index_shuffle)
            index = index[index_shuffle]
        self.indexBatch = index
        print("Dataloader batch size:" + str(batch_size) + " count:" + str(len(index)))
        print("Dataloader total:" + str(len(seq_len)) + " drop:" + str(len(seq_len) - len(index) * batch_size))
    def __len__(self):
        return len(self.indexBatch)

    def __getitem__(self, idx):
        return self.dataset.get_batch(self.indexBatch[idx])

    def get_tree(self, idx):
        return [self.dataset.get_tree(i) for i in self.indexBatch[idx]]

    def print_tree(self, idx):
        idx_list = self.indexBatch[idx]
        s = "--------------------------------------------------------\n"
        for i in idx_list:
            s += self.dataset.print_tree(i)
            s += "--------------------------------------------------------\n"
        return s

    @staticmethod
    def detection_collate(batch):
        return batch[0]

    def dataloader(self, num_workers=1):
        return DataLoader(
            self, batch_size=1, num_workers=num_workers, collate_fn=BatchGroupMeaningDataloader.detection_collate
        )
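
# A minimal sketch of driving the grouped batches through torch's DataLoader
# (the helper name is illustrative): each item the wrapper yields is already a
# full batch, so the outer batch_size stays at 1 and detection_collate unwraps it.
def _demo_grouped_loader():
    ds = MeaningDataset(4096, 8192, vocab_size=1024, size=256, use_cache=False)
    loader = BatchGroupMeaningDataloader(ds, batch_size=4).dataloader(num_workers=0)
    first = next(iter(loader))
    print(first["input_ids"].shape)
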
if __name__ == "__main__":
    md = MeaningDataset(100000, 115200, vocab_size=1024, size=1024, use_cache=False)
    md.set_mask([1], [-1])
    train, val = md.split(0.95)
    sample = md[920]
    print(md.print_tree(920))
    print(md.rank_idx[920])
    print(md.rank_all[920])
    mask = md.get_seq_mask(920, 0, -1)
    print(mask)
    mask = md.get_seq_mask(920, 1, 0)
    print(mask)
    mask = md.get_seq_mask(920, 1, -1)
    print(mask)
    mask = md.get_seq_mask(920, 1, 1)
    print(mask)
    assert all(
        np.equal(
            mask[0:57],
            np.array(
                [False] * 11 + [True] * 9 + [False] * 6 + [True] + [False] * 7
                + [True] * 4 + [False] * 4 + [True] * 9 + [False] * 6
            ),
        )
    ), "mask mismatch"
    freq = md.token_frequency()

    md = MeaningDataset(100000, 115200, vocab_size=1024, size=1024, min_subitem=2, use_cache=False)
    md.set_mask([0, 1], [0, -1])
    dl = BatchGroupMeaningDataloader(md, 1)
    length = len(dl)
    it = iter(dl)
    ne1 = next(it)
    tree = ne1["tree"]
    mask = ne1["mask"].cpu().numpy()
    t = MeaningMap.get_tree_str(tree, "")
    print(t)
    m, l = MeaningMap.get_tree_indexed_str(tree, mask, "")
    print(m)
    ne2 = next(it)
    ne3 = next(it)

    # map1 = dl.get_tree(0)
    # map2 = dl.get_tree(1)
    # print(dl.print_tree(0))
    # dl = DataLoader(
    #     train,
    #     num_workers=1,
    #     persistent_workers=True,
    #     shuffle=False,
    # )
    # it = iter(dl)
    # ne1 = next(it)
    # ne2 = next(it)
    # ne3 = next(it)
    # for i in range(10):
    #     print(next(it)["input_ids"].numpy().tolist())