# Witllm/wit/dataset/meaning_dataset.py


import os
import copy
import time

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset

from dataset.node_tree import NodeTree

# import warnings
# warnings.filterwarnings("ignore", ".*does not have many workers.*")
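
# Overview: MeaningMap assigns every integer "meaning" a token sequence by
# recursively expanding it into sub-meanings until only vocabulary ids
# (< vocab_size) remain.  MeaningDataset materialises those sequences for a
# range of meanings, and BatchGroupMeaningDataloader groups equal-length
# samples into batches.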


class MeaningMap:
    """Maps each "meaning" id to a token sequence.

    Ids below vocab_size are leaf tokens; every id at or above vocab_size
    expands into min_subitem..max_subitem smaller meanings, recursively,
    until only leaf tokens remain.
    """

    def __init__(self, size=1048576, vocab_size=4096, max_subitem=10, min_subitem=1, use_cache=True, seed=42):
        assert size > 0 and vocab_size > 0 and max_subitem > 0 and min_subitem > 0, "Invalid input"
        assert min_subitem <= max_subitem, "Invalid input"
        np.random.seed(seed)
        self.size = size
        self.vocab_size = vocab_size
        self.max_subitem = max_subitem
        self.min_subitem = min_subitem
        datastep = 0x8000000  # growth chunk for the flat sequence buffers
        path = "./data/"
        file = "structured_language_" + str(size) + "_" + str(vocab_size)
        file += "_" + str(max_subitem) + "_" + str(min_subitem)
        file_prop = path + file + "_prop.npy"
        file_data = path + file + "_data.npy"
        file_level = path + file + "_level.npy"
        file_rank_idx = path + file + "_rank_idx.npy"
        file_rank_all = path + file + "_rank_all.npy"
        start_time = time.time()
        if not os.path.exists(path):
            os.mkdir(path)
        if (
            os.path.exists(file_prop)
            and os.path.exists(file_data)
            and os.path.exists(file_level)
            and os.path.exists(file_rank_idx)
            and os.path.exists(file_rank_all)
            and use_cache
        ):
            print("Mapping Load from disk cache: " + file)
            # slhwm columns: start, len, height, weight, then the raw map.
            slhwm = np.load(file_prop)
            self.ms_map = slhwm[:, 4:]
            self.ms_data = np.load(file_data)
            self.ms_start = slhwm[:, 0]
            self.ms_len = slhwm[:, 1]
            self.ms_level = np.load(file_level)
            self.ms_rank_idx = np.load(file_rank_idx)
            self.ms_rank_all = np.load(file_rank_all)
            self.ms_height = slhwm[:, 2]
            self.ms_weight = slhwm[:, 3]
            print("Mapping Load end, elapsed:" + str(time.time() - start_time) + "s")
        else:
            print("Mapping Disk cache miss, build new one. size:" + str(size))
            index = np.arange(0, size)
            # Random expansion table: row i lists the sub-meanings of meaning
            # i, padded with -1.
            map = np.random.random((size, max_subitem))
            # Randomly drop entries beyond the first min_subitem so each row
            # keeps between min_subitem and max_subitem children.
            mask_zero = map.copy()
            mask_zero[:, 0:min_subitem] = 0.0
            mask_zero.sort(axis=1)
            thre = np.random.random((size)).reshape(-1, 1).repeat(max_subitem, axis=1)
            mask_zero = mask_zero > thre
            item_sum = map.sum(axis=1)
            # Scale each row so its sub-meaning ids stay below i, keeping the
            # expansion acyclic (meanings only reference earlier meanings).
            scale = (index / item_sum).reshape(-1, 1).repeat(max_subitem, axis=1)
            map = (map * scale).astype(np.int32)
            map[mask_zero] = -1
            # The first vocab_size meanings are the leaf tokens themselves.
            map[:vocab_size, 0] = np.arange(0, vocab_size)
            map[:vocab_size, 1:] = -1

            ms_start = np.zeros((size), dtype=np.int32)  # meaning sequence start
            ms_end = np.zeros((size), dtype=np.int32)  # meaning sequence end
            ms_len = np.zeros((size), dtype=np.int32)  # meaning sequence len
            ms_height = np.zeros((size), dtype=np.int32)  # meaning tree height
            ms_weight = np.zeros((size), dtype=np.int32)  # meaning tree weight
            ms_data = np.zeros((datastep), dtype=np.int32)  # meaning sequence
            ms_level = np.zeros((datastep), dtype=np.uint32)  # meaning level, vocab's level is 0
            ms_rank_idx = np.zeros((datastep), dtype=np.uint32)  # meaning index of all level
            ms_rank_all = np.zeros((datastep), dtype=np.uint32)  # meaning all of all level

            index = 0
            for i in range(self.vocab_size):
                ms_data[i] = i
                ms_start[i] = index
                ms_end[i] = index + 1
                ms_len[i] = 1
                ms_level[i] = 0
                ms_rank_idx[i] = 0xFFFFFFF
                ms_rank_all[i] = 0xFFFFFFF
                ms_height[i] = 0
                ms_weight[i] = 1
                index = index + 1
            for i in range(self.vocab_size, size):
                m = map[i]
                m = m[m >= 0]  # do not cut off the map such as [0]
                m_len = len(m)
                m_list = m.tolist()
                assert m_list, "map list can not be empty list"
                # Concatenate the already-built sequences of all sub-meanings.
                idx = np.concatenate([np.arange(ms_start[m], ms_end[m]) for m in m_list])
                idxidx = np.concatenate(
                    [np.ones(l, dtype=np.uint32) * i for i, l in enumerate(ms_end[m_list] - ms_start[m_list])]
                )
                len_ma = len(idx)
                end = index + len_ma
                # Grow the flat buffers in datastep-sized chunks when full.
                if ms_data.size < end:
                    ms_data = np.concatenate([ms_data, np.zeros((datastep), dtype=np.int32)])
                    ms_level = np.concatenate([ms_level, np.zeros((datastep), dtype=np.uint32)])
                    ms_rank_idx = np.concatenate([ms_rank_idx, np.zeros((datastep), dtype=np.uint32)])
                    ms_rank_all = np.concatenate([ms_rank_all, np.zeros((datastep), dtype=np.uint32)])
                ms_data[index:end] = ms_data[idx]
                ms_level[index:end] = ms_level[idx] + 1
                # Shift each token's packed path left one nibble and record its
                # sub-meaning's position (and the group size) at the new level.
                ms_rank_idx[index:end] = (ms_rank_idx[idx] * 16 + idxidx).astype(np.uint32)
                ms_rank_all[index:end] = (ms_rank_all[idx] * 16 + m_len).astype(np.uint32)
                ms_start[i] = index
                ms_end[i] = end
                ms_len[i] = len_ma
                ms_height[i] = max(ms_height[m_list]) + 1
                ms_weight[i] = sum(ms_weight[m_list])
                index = index + len_ma
                if i % 10000 == 0:
                    print(i)
            print("Mapping end, elapsed:" + str(time.time() - start_time) + "s")
            np.save(file_data, ms_data)
            np.save(file_level, ms_level)
            np.save(file_rank_idx, ms_rank_idx)
            np.save(file_rank_all, ms_rank_all)
            slhwm = np.concatenate(
                (
                    ms_start.reshape((-1, 1)),
                    ms_len.reshape((-1, 1)),
                    ms_height.reshape((-1, 1)),
                    ms_weight.reshape((-1, 1)),
                    map,
                ),
                axis=1,
            )
            np.save(file_prop, slhwm)
            self.ms_data = ms_data  # map[i]=ms_data[ms_start[i]:ms_start[i]+ms_len[i]]
            self.ms_level = ms_level
            self.ms_rank_idx = ms_rank_idx
            self.ms_rank_all = ms_rank_all
            self.ms_map = map  # ms_map[i] = [sub(i),sub(i),sub(i),sub(i)...sub(i)]
            self.ms_start = ms_start
            self.ms_len = ms_len
            self.ms_height = ms_height
            self.ms_weight = ms_weight
            print("Mapping Disk cache build end, elapsed:" + str(time.time() - start_time) + "s")

    def get_sequence(self, meaning):  # return sequence[meaning]
        start = self.ms_start[meaning]
        length = self.ms_len[meaning]
        return (
            self.ms_data[start : start + length],
            self.ms_level[start : start + length],
            self.ms_rank_idx[start : start + length],
            self.ms_rank_all[start : start + length],
        )
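
    # A worked sketch of get_sequence output (`mm` and all values below are
    # hypothetical, assuming a meaning with two direct leaves and one nested
    # sub-meaning):
    #
    #   seq, level, rank_idx, rank_all = mm.get_sequence(2000)
    #   # seq      -> leaf token ids, e.g. [17, 903, 41, 7]
    #   # level    -> tree depth of each token, e.g. [1, 1, 2, 2]
    #   # rank_idx -> each token's sibling positions along its path, packed
    #   #             one 4-bit nibble per level (root-most in the low nibble)
    #   # rank_all -> the matching sibling-group sizes, same packing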

    def get_tree(self, meaning):  # return meaning all sub items
        def get_tree_node(ms_map, meaning, vocab_size, parent, seqlist):
            ms = ms_map[meaning]
            for m in ms[ms >= 0].tolist():
                if m >= vocab_size:
                    pn = NodeTree(str(m), parent)
                    get_tree_node(ms_map, m, vocab_size, pn, seqlist)
                else:
                    pn = NodeTree("<" + str(m) + ">", parent)
                    seqlist.append(pn)

        root = NodeTree(str(meaning))
        seqlist = []
        get_tree_node(self.ms_map, meaning, self.vocab_size, root, seqlist)
        root.seq_node = seqlist
        return root

    def max_length(self):
        return max(self.ms_len)


class MeaningDataset(Dataset):
    def __init__(
        self,
        start,
        end,
        vocab_size,
        size=None,
        max_subitem=10,
        min_subitem=1,
        min_seq_len=2,
        seed=42,
        use_cache=True,
    ):
        np.random.seed(seed)
        self.start = start
        self.end = end
        self.vocab_size = vocab_size
        self.max_subitem = max_subitem
        self.min_subitem = min_subitem
        self.use_cache = use_cache
        self.min_seq_len = min_seq_len
        print("Build MeaningDataset from MeaningMap.")
        self.val_mask_level = None
        self.val_mask_idx = None
        self.seq = []
        self.level = []
        self.rank_idx = []
        self.rank_all = []
        self.seq_meaning = []
        map = self.get_meaning_map()
        self.m_height = map.ms_height
        self.m_weight = map.ms_weight
        if size:
            meanings = np.random.randint(start, end, size=(size))
        else:
            meanings = np.arange(start, end)

        def reverse_nibbles(x):
            # Reverse the order of the eight 4-bit nibbles in a uint32 array.
            out = (x & 0xF) << 28
            out = out + ((x & 0xF0) << 20)
            out = out + ((x & 0xF00) << 12)
            out = out + ((x & 0xF000) << 4)
            out = out + ((x & 0xF0000) >> 4)
            out = out + ((x & 0xF00000) >> 12)
            out = out + ((x & 0xF000000) >> 20)
            out = out + ((x & 0xF0000000) >> 28)
            return out

        seq_len = []
        for m in meanings:
            d, l, i, a = map.get_sequence(m)
            if len(d) >= min_seq_len:
                self.seq.append(d)
                self.level.append(l)
                self.seq_meaning.append(m)
                seq_len.append(len(d))
                dm = np.ones(i.shape, dtype=np.uint32)
                dm = ((dm * 0xFFFFFFFF) << (l * 4)).astype(np.uint32)
                shift = (8 - l) * 4
                # Reverse the nibble order so the deepest level sits in the
                # low nibble, then mark every nibble above the token's level
                # with 0xF.
                rank_idx = ((reverse_nibbles(i) >> shift) + dm).astype(np.uint32)
                rank_all = ((reverse_nibbles(a) >> shift) + dm).astype(np.uint32)
                self.rank_idx.append(rank_idx)
                self.rank_all.append(rank_all)
        unique, counts = np.unique(seq_len, return_counts=True)
        print("Build MeaningDataset end.")
        print("----------------------------------------------------------------")
        print("MeaningDataset start:" + str(start) + " end:" + str(end) + " space:" + str(end - start))
        print("MeaningDataset size:" + str(len(seq_len)))
        print("MeaningDataset max sequence length:" + str(max(unique)))
        print("MeaningDataset most popular sequence length:" + str(unique[np.argmax(counts)]))
        print("----------------------------------------------------------------")

    def __len__(self):
        return len(self.seq)

    def len(self):
        return len(self.seq)

    def get_meaning_map(self):
        return MeaningMap(self.end, self.vocab_size, self.max_subitem, self.min_subitem, self.use_cache)

    def set_mask(self, level=None, idx=None):
        # Validate the new mask settings before storing them.
        if level is not None and idx is not None:
            assert len(level) > 0, "len must > 0"
            assert len(level) == len(idx), "mask level and mask index must be same length"
            assert isinstance(level, list), "mask level must be list"
            assert isinstance(idx, list), "mask index must be list"
        self.val_mask_level = level
        self.val_mask_idx = idx

    def __getitem__(self, idx):
        return self.get_batch([idx])

    def get_batch(self, idx_list):  # all sequences in idx_list must have equal length
        data = [self.seq[i] for i in idx_list]
        output = {}
        data = torch.tensor(np.stack(data, axis=0)).long()
        output["input_ids"] = data
        output["labels"] = data.clone()
        output["token_type_ids"] = torch.zeros(data.shape)
        output["val_mask"] = self.get_seq_mask_tensor(idx_list)
        return output

    def get_token(self, idx):
        return self.seq[idx]

    def get_meaning(self, idx):
        return self.seq_meaning[idx]

    def copy(self, start, end):
        new = copy.deepcopy(self)
        new.seq = new.seq[start:end]
        new.level = new.level[start:end]
        new.rank_idx = new.rank_idx[start:end]
        new.rank_all = new.rank_all[start:end]
        new.seq_meaning = new.seq_meaning[start:end]
        new.val_mask_level = self.val_mask_level
        new.val_mask_idx = self.val_mask_idx
        return new

    def split(self, ratio):
        l = self.len()
        middle = int(l * ratio)
        return self.copy(0, middle), self.copy(middle, l)
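
    # For example, train, val = dataset.split(0.9) keeps the first 90% of the
    # samples for train and the rest for val (an ordered split, no reshuffle).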

    def get_seq_mask(self, idx, level, index):
        # True where the token's ancestor at `level` (0 = the token itself
        # among its siblings) sits at position `index`; a negative `index`
        # counts back from the end of the sibling group.
        # assert index < 15, "index must < 15"
        # assert level < 8, "level must < 8"
        rank_idx = (self.rank_idx[idx] >> (4 * level)).astype(np.int32) & 0xF
        rank_all = (self.rank_all[idx] >> (4 * level)).astype(np.int32) & 0xF
        return rank_idx == (rank_all + index if index < 0 else index)

    def get_seq_mask_tensor(self, idx_list):
        if self.val_mask_level is not None and self.val_mask_idx is not None:
            mask = torch.tensor(
                np.stack(
                    [self.get_seq_mask(idx, self.val_mask_level[0], self.val_mask_idx[0]) for idx in idx_list], axis=0
                )
            )
            for i, l in enumerate(self.val_mask_level[1:]):
                mask = mask & torch.tensor(
                    np.stack([self.get_seq_mask(idx, l, self.val_mask_idx[i + 1]) for idx in idx_list], axis=0)
                )
            return mask
        else:
            return None
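
    # Examples of the mask semantics (as exercised in __main__ below):
    #   set_mask([0], [-1])       -> tokens that are the last of their siblings
    #   set_mask([0, 1], [0, -1]) -> first-of-siblings tokens whose parent is
    #                                the last of its siblings (levels are ANDed)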


class BatchGroupMeaningDataloader(Dataset):
    """Groups MeaningDataset samples of equal sequence length into batches."""

    def __init__(self, meaning_dataset: MeaningDataset, batch_size, shuffle=True, drop_last=True):
        self.meaning_dataset = meaning_dataset
        self.batch_size = batch_size
        self.drop_last = drop_last  # note: incomplete batches are always dropped below
        seq_len = np.array([len(s) for s in meaning_dataset.seq])
        unique, counts = np.unique(seq_len, return_counts=True)
        # Group sample indices by sequence length so every batch stacks
        # sequences of a single length.
        gl = {}
        for u in unique:
            gl[u] = np.where(seq_len == u)[0]
        lens = list(gl.keys())
        gs = {}
        for k in gl.keys():
            sl = gl[k].copy()
            if shuffle:
                np.random.shuffle(sl)
            gs[k] = sl
        index = np.zeros((0, batch_size), dtype=np.int64)
        for l in lens:
            batch = len(gs[l]) // batch_size
            new = gs[l][0 : batch * batch_size].reshape(batch, batch_size)
            index = np.concatenate((index, new), axis=0)
        if shuffle:
            index_shuffle = np.arange(0, index.shape[0])
            np.random.shuffle(index_shuffle)
            index = index[index_shuffle]
        self.indexBatch = index
        print("Dataloader batch size:" + str(batch_size) + " count:" + str(len(index)))
        print("Dataloader total:" + str(len(seq_len)) + " drop:" + str(len(seq_len) - len(index) * batch_size))

    def __len__(self):
        return len(self.indexBatch)

    def __getitem__(self, idx):
        return self.meaning_dataset.get_batch(self.indexBatch[idx])

    def get_tree(self, idx):
        # MeaningDataset does not define a per-sample get_tree, so rebuild
        # each tree from the meaning map instead.
        map = self.meaning_dataset.get_meaning_map()
        return [map.get_tree(self.meaning_dataset.get_meaning(i)) for i in self.indexBatch[idx]]

    def print_tree(self, idx):
        # NOTE: relies on MeaningDataset.print_tree, which is assumed to be
        # provided elsewhere in the repo (it is not defined in this file).
        idx_list = self.indexBatch[idx]
        s = "--------------------------------------------------------\n"
        for i in idx_list:
            s += self.meaning_dataset.print_tree(i)
            s += "--------------------------------------------------------\n"
        return s

    @staticmethod
    def detection_collate(batch):
        # Each dataset item is already a full batch; unwrap the singleton.
        return batch[0]

    def dataloader(self, num_workers=1):
        return DataLoader(
            self, batch_size=1, num_workers=num_workers, collate_fn=BatchGroupMeaningDataloader.detection_collate
        )
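
# A minimal usage sketch (hypothetical ranges; building the map can take a
# while on first use and is cached under ./data/):
#
#   md = MeaningDataset(100000, 101000, vocab_size=1024)
#   dl = BatchGroupMeaningDataloader(md, batch_size=4).dataloader(num_workers=0)
#   batch = next(iter(dl))  # dict with input_ids / labels / token_type_ids / val_mask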


if __name__ == "__main__":
    md = MeaningDataset(100000, 115200, vocab_size=1024, size=1024, use_cache=False)
    md.set_mask([1], [-1])
    train, val = md.split(0.95)
    item = md[920]
    print(md.print_tree(920))  # assumes MeaningDataset.print_tree, defined elsewhere in the repo
    print(md.rank_idx[920])
    print(md.rank_all[920])
    mask = md.get_seq_mask(920, 0, -1)
    print(mask)
    mask = md.get_seq_mask(920, 1, 0)
    print(mask)
    mask = md.get_seq_mask(920, 1, -1)
    print(mask)
    mask = md.get_seq_mask(920, 1, 1)
    print(mask)
    expected = np.array(
        [False] * 11 + [True] * 9 + [False] * 6 + [True] * 1 + [False] * 7
        + [True] * 4 + [False] * 4 + [True] * 9 + [False] * 6
    )
    assert all(np.equal(mask[0:57], expected)), "mask mismatch"

    md = MeaningDataset(100000, 115200, vocab_size=1024, size=1024, min_subitem=2, use_cache=False)
    md.set_mask([0, 1], [0, -1])
    dl = BatchGroupMeaningDataloader(md, 1)
    length = len(dl)
    it = iter(dl)
    ne1 = next(it)
    mask = ne1["val_mask"].cpu().numpy()
    tree = dl.get_tree(0)
    # MeaningMap.print_tree and MeaningMap.get_tree_indexed_str are assumed to
    # be provided elsewhere in the repo; they are not defined in this file.
    t = MeaningMap.print_tree(tree)
    print(t)
    m, l = MeaningMap.get_tree_indexed_str(tree, mask, "")
    print(m)
    ne2 = next(it)
    ne3 = next(it)