559 lines
21 KiB
Python
559 lines
21 KiB
Python
import os
|
||
import torch, datasets
|
||
import math, gc, time, random, copy
|
||
from itertools import chain
|
||
from typing import Dict, Tuple
|
||
from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split
|
||
import numpy as np
|
||
from torch.utils.data import BatchSampler
|
||
from dataset.node_tree import NodeTree
|
||
|
||
# import warnings
|
||
# warnings.filterwarnings("ignore", ".*does not have many workers.*")
|
||
|
||
|
||
class MeaningMap:
    """Randomly generated hierarchy of "meanings" built on top of a vocabulary.

    Meanings ``0 .. vocab_size - 1`` are the vocabulary tokens themselves.
    Every other meaning ``i`` is composed of ``min_subitem`` to ``max_subitem``
    child meanings. Their indices come from scaling a row of random weights so
    that the row sums to at most ``i``. For ``max_subitem > 1`` every child
    therefore has a smaller index than its parent, and the structure is
    acyclic. Recursively expanding a meaning yields a flat token sequence, and
    all sequences are stored back to back in ``ms_data``.

    Per-token arrays (parallel to ``ms_data``):
        ms_level    -- number of composition steps between the token and the
                       meaning whose sequence it belongs to (0 for vocabulary).
        ms_rank_idx -- 4 bits per composition step: the child position taken
                       at that step. The most recent (outermost) step is in the
                       lowest nibble. Vocabulary tokens start from 0xFFFFFFF,
                       so unused higher nibbles read 0xF. Values wrap at 32
                       bits, so only the last 8 steps are kept.
        ms_rank_all -- same layout, but each nibble holds the number of
                       children at that step.

    Per-meaning arrays (length ``size``):
        ms_start / ms_len -- slice of ``ms_data`` holding the meaning.
        ms_height         -- tree height (0 for vocabulary).
        ms_weight         -- number of leaf tokens.
        ms_map            -- child indices, padded with -1.

    NOTE(review): a nibble can only hold positions/counts up to 15, so
    ``max_subitem`` above 15 corrupts the rank words.
    """

    def __init__(
        self,
        size=1048576,
        vocab_size=4096,
        max_subitem=10,
        min_subitem=1,
        use_cache=True,
        seed=42,
        datastep=0x8000000,
        path="./data/",
    ):
        """Load the meaning tables from the disk cache, or generate and cache them.

        Args:
            size: total number of meanings (vocabulary included).
            vocab_size: number of leaf meanings; must not exceed ``size``.
            max_subitem: maximum number of children of a composed meaning.
            min_subitem: minimum number of children of a composed meaning.
            use_cache: reuse previously saved tables when all cache files exist.
            seed: seed passed to the global NumPy RNG (``np.random.seed``).
            datastep: growth increment, in tokens, of the per-token buffers.
            path: directory holding the ``.npy`` cache files.
        """
        assert size > 0 and vocab_size > 0 and max_subitem > 0 and min_subitem > 0, "Invalid input"
        assert min_subitem <= max_subitem, "Invalid input"
        # The vocabulary rows are written into the first `vocab_size` rows of the map.
        assert vocab_size <= size and datastep > 0, "Invalid input"

        np.random.seed(seed)
        self.size = size
        self.vocab_size = vocab_size
        self.max_subitem = max_subitem
        self.min_subitem = min_subitem

        # NOTE(review): the cache key does not include `seed`. With use_cache=True,
        # a map generated under a different seed is silently reused.
        file = "structured_language_" + str(size) + "_" + str(vocab_size)
        file += "_" + str(max_subitem) + "_" + str(min_subitem)
        file_prop = os.path.join(path, file + "_prop.npy")
        file_data = os.path.join(path, file + "_data.npy")
        file_level = os.path.join(path, file + "_level.npy")
        file_rank_idx = os.path.join(path, file + "_rank_idx.npy")
        file_rank_all = os.path.join(path, file + "_rank_all.npy")
        cache_files = (file_prop, file_data, file_level, file_rank_idx, file_rank_all)

        start_time = time.time()
        os.makedirs(path, exist_ok=True)

        if use_cache and all(os.path.exists(f) for f in cache_files):
            print("Mapping Load from disk cache: " + file)
            # prop layout per row: [start, len, height, weight, child_0 .. child_{max_subitem-1}]
            slhwm = np.load(file_prop)
            self.ms_map = slhwm[:, 4:]
            self.ms_data = np.load(file_data)
            self.ms_start = slhwm[:, 0]
            self.ms_len = slhwm[:, 1]
            self.ms_level = np.load(file_level)
            self.ms_rank_idx = np.load(file_rank_idx)
            self.ms_rank_all = np.load(file_rank_all)
            self.ms_height = slhwm[:, 2]
            self.ms_weight = slhwm[:, 3]
            print("Mapping Load end, elapsed:" + str(time.time() - start_time) + "s")
        else:
            print("Mapping Disk cache miss, build new one. size:" + str(size))
            self._build(datastep, cache_files, start_time)
            print("Mapping Disk cache build end, elapsed:" + str(time.time() - start_time) + "s")

    @staticmethod
    def _random_sub_map(size, vocab_size, max_subitem, min_subitem):
        """Draw the (size, max_subitem) int32 child table; unused slots are -1."""
        row_index = np.arange(0, size)
        weights = np.random.random((size, max_subitem))

        # Choose how many children each row keeps. Zero the first `min_subitem`
        # values of a copy, sort each row, and mask every column whose value
        # exceeds a per-row random threshold. The zeros sort to the front and
        # never exceed the threshold, so at least `min_subitem` leading columns
        # survive.
        mask_zero = weights.copy()
        mask_zero[:, 0:min_subitem] = 0.0
        mask_zero.sort(axis=1)
        thre = np.random.random((size)).reshape(-1, 1).repeat(max_subitem, axis=1)
        mask_zero = mask_zero > thre

        # Scale every row so its weights sum to the row index; truncation then
        # yields child indices no larger than the parent index.
        item_sum = weights.sum(axis=1)
        scale = (row_index / item_sum).reshape(-1, 1).repeat(max_subitem, axis=1)
        sub_map = (weights * scale).astype(np.int32)
        sub_map[mask_zero] = -1

        # Vocabulary rows are leaves that refer to themselves.
        sub_map[:vocab_size, 0] = np.arange(0, vocab_size)
        sub_map[:vocab_size, 1:] = -1
        return sub_map

    @staticmethod
    def _grow(buf, needed, datastep):
        """Return `buf`, zero-extended in whole `datastep` chunks to hold at least `needed` items."""
        if buf.size >= needed:
            return buf
        chunks = -(-(needed - buf.size) // datastep)  # ceiling division
        return np.concatenate([buf, np.zeros((chunks * datastep), dtype=buf.dtype)])

    def _build(self, datastep, cache_files, start_time):
        """Generate every table, save it to `cache_files` and store it on self."""
        file_prop, file_data, file_level, file_rank_idx, file_rank_all = cache_files
        size = self.size
        vocab_size = self.vocab_size

        sub_map = self._random_sub_map(size, vocab_size, self.max_subitem, self.min_subitem)

        ms_start = np.zeros((size), dtype=np.int32)  # meaning sequence start
        ms_end = np.zeros((size), dtype=np.int32)  # meaning sequence end
        ms_len = np.zeros((size), dtype=np.int32)  # meaning sequence len
        ms_height = np.zeros((size), dtype=np.int32)  # meaning tree height
        ms_weight = np.zeros((size), dtype=np.int32)  # meaning tree weight (leaf count)

        capacity = max(datastep, vocab_size)
        ms_data = np.zeros((capacity), dtype=np.int32)  # concatenated token sequences
        ms_level = np.zeros((capacity), dtype=np.uint32)  # per-token level, vocabulary is 0
        ms_rank_idx = np.zeros((capacity), dtype=np.uint32)  # per-token child positions
        ms_rank_all = np.zeros((capacity), dtype=np.uint32)  # per-token child counts

        # Vocabulary meanings: a single token each.
        vocab = np.arange(vocab_size)
        ms_data[:vocab_size] = vocab
        ms_start[:vocab_size] = vocab
        ms_end[:vocab_size] = vocab + 1
        ms_len[:vocab_size] = 1
        ms_level[:vocab_size] = 0
        ms_rank_idx[:vocab_size] = 0xFFFFFFF
        ms_rank_all[:vocab_size] = 0xFFFFFFF
        ms_height[:vocab_size] = 0
        ms_weight[:vocab_size] = 1
        index = vocab_size

        for i in range(vocab_size, size):
            row = sub_map[i]
            m_list = row[row >= 0].tolist()  # children of meaning i, in order
            m_len = len(m_list)
            assert m_list, "map list can not be empty list"

            # Token positions of each child's expansion, concatenated in child
            # order, and for every such token the position of its child within i.
            child_lens = ms_end[m_list] - ms_start[m_list]
            idx = np.concatenate([np.arange(ms_start[c], ms_end[c]) for c in m_list])
            idxidx = np.repeat(np.arange(m_len, dtype=np.uint32), child_lens)
            len_ma = len(idx)
            end = index + len_ma

            # Grow all per-token buffers together when meaning i does not fit.
            ms_data = self._grow(ms_data, end, datastep)
            ms_level = self._grow(ms_level, end, datastep)
            ms_rank_idx = self._grow(ms_rank_idx, end, datastep)
            ms_rank_all = self._grow(ms_rank_all, end, datastep)

            # Copy the children's tokens after everything written so far and
            # push one more nibble onto each rank word (uint32 arithmetic wraps).
            ms_data[index:end] = ms_data[idx]
            ms_level[index:end] = ms_level[idx] + 1
            ms_rank_idx[index:end] = (ms_rank_idx[idx] * 16 + idxidx).astype(np.uint32)
            ms_rank_all[index:end] = (ms_rank_all[idx] * 16 + m_len).astype(np.uint32)

            ms_start[i] = index
            ms_end[i] = end
            ms_len[i] = len_ma
            ms_height[i] = ms_height[m_list].max() + 1
            ms_weight[i] = ms_weight[m_list].sum()
            index = end
            if i % 10000 == 0:
                print(i)

        print("Mapping end, elapsed:" + str(time.time() - start_time) + "s")

        np.save(file_data, ms_data)
        np.save(file_level, ms_level)
        np.save(file_rank_idx, ms_rank_idx)
        np.save(file_rank_all, ms_rank_all)
        slhwm = np.concatenate(
            (
                ms_start.reshape((-1, 1)),
                ms_len.reshape((-1, 1)),
                ms_height.reshape((-1, 1)),
                ms_weight.reshape((-1, 1)),
                sub_map,
            ),
            axis=1,
        )
        np.save(file_prop, slhwm)

        self.ms_data = ms_data  # tokens of meaning i: ms_data[ms_start[i]:ms_start[i]+ms_len[i]]
        self.ms_level = ms_level
        self.ms_rank_idx = ms_rank_idx
        self.ms_rank_all = ms_rank_all
        self.ms_map = sub_map  # ms_map[i] = [child, child, ..., -1, -1]
        self.ms_start = ms_start
        self.ms_len = ms_len
        self.ms_height = ms_height
        self.ms_weight = ms_weight

    def get_sequence(self, meaning):
        """Return ``(tokens, levels, rank_idx, rank_all)`` array views for `meaning`."""
        start = self.ms_start[meaning]
        length = self.ms_len[meaning]
        span = slice(start, start + length)
        return (
            self.ms_data[span],
            self.ms_level[span],
            self.ms_rank_idx[span],
            self.ms_rank_all[span],
        )

    def get_nodetree(self, meaning):
        """Build a NodeTree for `meaning` from its child table.

        For every child, ``NodeTree(label, parent)`` is created: composed
        children are labelled with their index and expanded recursively, and
        vocabulary leaves are labelled ``"<index>"``. The leaf nodes, in
        sequence order, are stored on the root as ``root.seq_node``. (Assumes
        the NodeTree constructor links the node under `parent`; see
        dataset.node_tree.)
        """

        def get_tree_node(ms_map, meaning, vocab_size, parent, seqlist):
            ms = ms_map[meaning]
            for m in ms[ms >= 0].tolist():
                if m >= vocab_size:
                    pn = NodeTree(str(m), parent)
                    get_tree_node(ms_map, m, vocab_size, pn, seqlist)
                else:
                    pn = NodeTree("<" + str(m) + ">", parent)
                    seqlist.append(pn)

        root = NodeTree(str(meaning))
        seqlist = []
        get_tree_node(self.ms_map, meaning, self.vocab_size, root, seqlist)
        root.seq_node = seqlist
        return root

    def get_level_change(self, meaning):
        """Walk `meaning` depth-first and return per-token level-change lists.

        Returns ``(current_to_common, common_to_current)``, with one entry per
        leaf token in sequence order. ``common_to_current[k]`` counts the
        composed sub-meanings entered between token ``k-1`` and token ``k``.
        For ``k == 0`` the count starts from 1.

        NOTE(review): as written, ``current_to_common`` is incremented exactly
        once per leaf (right after the leaf is recorded), so every returned
        entry is 1. Leaving a sub-meaning does not add to it. Confirm whether
        that is intended.
        """

        def level_change(ms_map, meaning, current_to_common, common_to_current):
            ms = ms_map[meaning]
            for m in ms[ms >= 0].tolist():
                if m >= self.vocab_size:
                    common_to_current[-1] = common_to_current[-1] + 1
                    level_change(ms_map, m, current_to_common, common_to_current)
                else:
                    # Open the slot for the next token, then close the current one.
                    current_to_common.append(0)
                    common_to_current.append(0)
                    current_to_common[-2] = current_to_common[-2] + 1

        common_to_current = [1]
        current_to_common = [0]
        level_change(self.ms_map, meaning, current_to_common, common_to_current)
        # The last slot was opened for a token that does not exist.
        return current_to_common[:-1], common_to_current[:-1]

    def get_relation_table(self, meaning):
        """Return an (n, n) 0/1 int matrix marking which earlier tokens each token relates to.

        Row ``i`` is set from a start column up to and including ``i``. The
        start column is found by walking backwards from ``i - 1`` and
        accumulating the level changes from `get_level_change`. The walk stops
        after the step at which the running balance first goes negative.
        """
        current_to_common, common_to_current = self.get_level_change(meaning)
        width = len(current_to_common)
        relation = np.zeros((width, width), dtype=int)
        relation[0, 0] = 1
        for i in range(1, width):
            ori = current_to_common[i] - common_to_current[i]
            start_index = width
            for s in range(i - 1, -1, -1):
                if ori < 0:
                    break
                ori = ori - common_to_current[s] + current_to_common[s]
                start_index = s
            # NOTE(review): if the balance is negative before the first step,
            # start_index stays `width` and row i stays all zero (no diagonal).
            relation[i, start_index : i + 1] = 1
        return relation

    def max_length(self):
        """Return the length of the longest meaning sequence."""
        return self.ms_len.max()
|
||
|
||
|
||
class MeaningDataset(Dataset):
    """Dataset of token sequences for meanings sampled from a MeaningMap.

    Every kept item stores its token sequence, per-token levels, the meaning
    id, and the per-token rank_idx / rank_all words of the map with their
    nibble order reversed (see ``__init__``). ``set_mask`` chooses which
    tokens are True in the ``val_mask`` returned by ``get_batch``.
    """

    def __init__(
        self,
        start,
        end,
        vocab_size,
        size=None,
        max_subitem=10,
        min_subitem=1,
        min_seq_len=2,
        seed=42,
        use_cache=True,
    ):
        """Sample meanings from ``[start, end)`` and keep those with at least ``min_seq_len`` tokens.

        If ``size`` is truthy, ``size`` meanings are drawn with
        ``np.random.randint`` (with replacement) after seeding the global RNG
        with ``seed``. Otherwise every meaning in the range is used once, in
        order.
        """
        self.start = start
        self.end = end
        self.vocab_size = vocab_size
        self.max_subitem = max_subitem
        self.min_subitem = min_subitem
        self.use_cache = use_cache
        self.min_seq_len = min_seq_len
        print("Build MeaningDataset from MeaningMap.")
        self.val_mask_level = None
        self.val_mask_idx = None
        self.seq = []
        self.level = []
        self.rank_idx = []
        self.rank_all = []
        self.seq_meaning = []

        meaning_map = self.get_meaning_map()
        self.m_height = meaning_map.ms_height
        self.m_weight = meaning_map.ms_weight

        # MeaningMap re-seeds the global RNG with its own default seed and
        # consumes random numbers while building. Seed only after the map
        # exists, so that `seed` really controls the sampled meanings.
        np.random.seed(seed)
        if size:
            meanings = np.random.randint(start, end, size=(size))
        else:
            meanings = np.arange(start, end)

        seq_len = []
        for m in meanings:
            d, l, i, a = meaning_map.get_sequence(m)
            if len(d) < min_seq_len:
                continue
            self.seq.append(d)
            self.level.append(l)
            self.seq_meaning.append(m)
            seq_len.append(len(d))

            # Reverse the 8 nibbles, then shift right so a token's `l`
            # meaningful nibbles land at the bottom. Nibble 0 then describes
            # the innermost composition step and nibble l-1 the outermost.
            # `dm` fills nibbles >= l with 0xF.
            # NOTE(review): levels above 8 are not representable here
            # (the shift would underflow).
            dm = np.ones(i.shape, dtype=np.uint32)
            dm = ((dm * 0xFFFFFFFF) << (l * 4)).astype(np.uint32)
            shift = (8 - l) * 4
            self.rank_idx.append(((self._reverse_nibbles(i) >> shift) + dm).astype(np.uint32))
            self.rank_all.append(((self._reverse_nibbles(a) >> shift) + dm).astype(np.uint32))

        print("Build MeaningDataset end.")
        print("----------------------------------------------------------------")
        print("MeaningDataset start:" + str(start) + " end:" + str(end) + " space:" + str(end - start))
        print("MeaningDataset size:" + str(len(seq_len)))
        if seq_len:  # statistics are undefined for an empty dataset
            unique, counts = np.unique(seq_len, return_counts=True)
            print("MeaningDataset max sequence length:" + str(max(unique)))
            print("MeaningDataset most popular sequence length:" + str(unique[np.argmax(counts)]))
        print("----------------------------------------------------------------")

    @staticmethod
    def _reverse_nibbles(v):
        """Reverse the order of the eight 4-bit nibbles of each uint32 in `v`."""
        r = (v & 0xF) << 28
        r = r + ((v & 0xF0) << 20)
        r = r + ((v & 0xF00) << 12)
        r = r + ((v & 0xF000) << 4)
        r = r + ((v & 0xF0000) >> 4)
        r = r + ((v & 0xF00000) >> 12)
        r = r + ((v & 0xF000000) >> 20)
        r = r + ((v & 0xF0000000) >> 28)
        return r

    def __len__(self):
        return len(self.seq)

    def len(self):
        """Number of kept sequences (same as ``len(self)``)."""
        return len(self.seq)

    def get_meaning_map(self):
        """Build (or load from cache) the MeaningMap covering meanings ``[0, end)``."""
        return MeaningMap(self.end, self.vocab_size, self.max_subitem, self.min_subitem, self.use_cache)

    def set_mask(self, level=None, idx=None):
        """Set the (level, index) pairs used to build ``val_mask``; pass None to disable.

        ``level[k]`` selects a nibble of the reversed rank words, and
        ``idx[k]`` is the child position required at that level. Negative
        values count from the last child. All pairs are ANDed together.
        """
        if level is not None and idx is not None:
            assert isinstance(level, (list, tuple)), "mask level must be a list or tuple"
            assert isinstance(idx, (list, tuple)), "mask index must be a list or tuple"
            assert len(level) > 0, "len must > 0"
            assert len(level) == len(idx), "mask level and mask index must be same length"
        self.val_mask_level = level
        self.val_mask_idx = idx

    def __getitem__(self, idx):
        return self.get_batch([idx])

    def get_batch(self, idx_list):  # all selected sequences must have equal length
        """Stack the sequences at `idx_list` into a model-input dict."""
        data = torch.tensor(np.stack([self.seq[i] for i in idx_list], axis=0)).long()
        output = {}
        output["input_ids"] = data
        output["labels"] = data.clone()
        output["token_type_ids"] = torch.zeros(data.shape)
        output["val_mask"] = self.get_seq_mask_tensor(idx_list)
        output["meaning"] = [self.seq_meaning[i] for i in idx_list]
        return output

    def get_token(self, idx):
        """Token sequence of item `idx`."""
        return self.seq[idx]

    def get_meaning(self, idx):
        """Meaning id of item `idx`."""
        return self.seq_meaning[idx]

    def copy(self, start, end):
        """Deep copy of this dataset restricted to items ``[start, end)``; the mask setting is shared."""
        new = copy.deepcopy(self)
        new.seq = new.seq[start:end]
        new.level = new.level[start:end]
        new.rank_idx = new.rank_idx[start:end]
        new.rank_all = new.rank_all[start:end]
        new.seq_meaning = new.seq_meaning[start:end]
        new.val_mask_level = self.val_mask_level
        new.val_mask_idx = self.val_mask_idx
        return new

    def split(self, ratio):
        """Split into the first ``int(len * ratio)`` items and the rest."""
        total = self.len()
        middle = int(total * ratio)
        return self.copy(0, middle), self.copy(middle, total)

    def get_seq_mask(self, idx, level, index):
        """Boolean array: tokens of item `idx` whose child position at `level` equals `index`.

        A negative `index` is relative to the child count at that level
        (-1 is the last child).
        """
        # NOTE(review): only index < 15 and level < 8 fit the 4-bit / 32-bit encoding.
        rank_idx = (self.rank_idx[idx] >> (4 * level)).astype(np.int32) & 0xF
        rank_all = (self.rank_all[idx] >> (4 * level)).astype(np.int32) & 0xF
        return rank_idx == (rank_all + index if index < 0 else index)

    def get_seq_mask_tensor(self, idx_list):
        """AND of `get_seq_mask` over all configured (level, index) pairs, or None if no mask is set."""
        if self.val_mask_level is None or self.val_mask_idx is None:
            return None
        mask = None
        for level, index in zip(self.val_mask_level, self.val_mask_idx):
            part = torch.tensor(np.stack([self.get_seq_mask(i, level, index) for i in idx_list], axis=0))
            mask = part if mask is None else mask & part
        return mask
|
||
|
||
|
||
class BatchGroupMeaningDataloader(Dataset):
    """Dataset whose items are whole batches of equal-length sequences.

    Indices of ``meaning_dataset.seq`` are grouped by sequence length and cut
    into batches of ``batch_size``. Each item is
    ``meaning_dataset.get_batch(batch_indices)``. Use ``dataloader()`` to wrap
    it in a torch DataLoader that yields those batches unchanged.
    """

    def __init__(self, meaning_dataset: "MeaningDataset", batch_size, shuffle=True, drop_last=True):
        """Group item indices by sequence length and cut them into batches.

        Args:
            meaning_dataset: source dataset; only ``.seq`` is read here and
                ``.get_batch`` at item access.
            batch_size: number of sequences per batch.
            shuffle: shuffle indices inside each length group and the order of
                batches, using the global NumPy RNG.
            drop_last: drop each length group's incomplete trailing batch; when
                False it is kept as a smaller batch.
        """
        assert batch_size > 0, "batch_size must > 0"
        self.meaning_dataset = meaning_dataset
        self.batch_size = batch_size
        self.drop_last = drop_last

        seq_len = np.array([len(s) for s in meaning_dataset.seq])
        unique = np.unique(seq_len)

        # Indices of each length group (ascending length), optionally shuffled.
        groups = {}
        for u in unique:
            members = np.where(seq_len == u)[0]
            if shuffle:
                np.random.shuffle(members)
            groups[u] = members

        if drop_last:
            # Full batches only, as a (num_batches, batch_size) int64 array.
            index = np.zeros((0, batch_size), dtype=np.int64)
            for u in unique:
                members = groups[u]
                n_full = len(members) // batch_size
                full = members[0 : n_full * batch_size].reshape(n_full, batch_size)
                index = np.concatenate((index, full), axis=0)
            kept = index.shape[0] * batch_size
        else:
            # A list of 1-D index arrays; the last batch of a group may be shorter.
            index = []
            for u in unique:
                members = groups[u]
                index.extend(members[s : s + batch_size] for s in range(0, len(members), batch_size))
            kept = len(seq_len)

        if shuffle:
            order = np.arange(0, len(index))
            np.random.shuffle(order)
            index = index[order] if drop_last else [index[j] for j in order]

        self.indexBatch = index
        print("Dataloader batch size:" + str(batch_size) + " count:" + str(len(index)))
        print("Dataloader total:" + str(len(seq_len)) + " drop:" + str(len(seq_len) - kept))

    def __len__(self):
        return len(self.indexBatch)

    def __getitem__(self, idx):
        return self.meaning_dataset.get_batch(self.indexBatch[idx])

    @staticmethod
    def detection_collate(batch):
        """Collate function for batch_size=1 loaders: unwrap the single pre-built batch."""
        return batch[0]

    def dataloader(self, num_workers=1):
        """Wrap this dataset in a DataLoader that yields one pre-built batch per step."""
        return DataLoader(
            self, batch_size=1, num_workers=num_workers, collate_fn=BatchGroupMeaningDataloader.detection_collate
        )
|
||
|
||
|
||
if __name__ == "__main__":
    # Demo / smoke check. It builds two datasets without the disk cache, so it
    # is slow and regenerates the cache files under ./data/.

    md = MeaningDataset(100000, 115200, vocab_size=1024, size=1024, use_cache=False)
    md.set_mask([1], [-1])
    train, val = md.split(0.95)
    fdaf = md[920]
    print(md.get_meaning(920))
    print(md.get_token(920))
    print(md.rank_idx[920])
    print(md.rank_all[920])
    mask = md.get_seq_mask(920, 0, -1)
    print(mask)
    mask = md.get_seq_mask(920, 1, 0)
    print(mask)
    mask = md.get_seq_mask(920, 1, -1)
    print(mask)
    mask = md.get_seq_mask(920, 1, 1)
    print(mask)

    # NOTE(review): these expected values were recorded for one particular
    # generated dataset; regenerate them whenever map generation or sampling changes.
    expected = np.array(
        [False] * 11
        + [True] * 9
        + [False] * 6
        + [True]
        + [False] * 7
        + [True] * 4
        + [False] * 4
        + [True] * 9
        + [False] * 6
    )
    assert np.array_equal(mask[0:57], expected), "False"

    md = MeaningDataset(100000, 115200, vocab_size=1024, size=1024, min_subitem=2, use_cache=False)
    md.set_mask([0, 1], [0, -1])
    dl = BatchGroupMeaningDataloader(md, 1)
    length = len(dl)
    print("batches:", length)
    it = iter(dl)
    ne1 = next(it)
    mask = ne1["val_mask"].cpu().numpy()
    print(ne1["meaning"])
    print(ne1["input_ids"])
    print(mask)
    ne2 = next(it)
    ne3 = next(it)
|