import os
import torch, datasets
import math, gc, time, random, copy
from itertools import chain
from typing import Dict, Tuple
from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split
import numpy as np
from torch.utils.data import BatchSampler

from dataset.node_tree import NodeTree

# import warnings
# warnings.filterwarnings("ignore", ".*does not have many workers.*")


class MeaningMap:
    def __init__(self, size=1048576, vocab_size=4096, max_subitem=10, min_subitem=1, use_cache=True, seed=42):
        assert size > 0 and vocab_size > 0 and max_subitem > 0 and min_subitem > 0, "Invalid input"
        assert min_subitem <= max_subitem, "Invalid input"
        np.random.seed(seed)
        self.size = size
        self.vocab_size = vocab_size
        self.max_subitem = max_subitem
        self.min_subitem = min_subitem
        datastep = 0x8000000
        path = "./data/"
        file = "structured_language_" + str(size) + "_" + str(vocab_size)
        file += "_" + str(max_subitem) + "_" + str(min_subitem)
        file_prop = path + file + "_prop.npy"
        file_data = path + file + "_data.npy"
        file_level = path + file + "_level.npy"
        file_rank_idx = path + file + "_rank_idx.npy"
        file_rank_all = path + file + "_rank_all.npy"

        start_time = time.time()

        if not os.path.exists(path):
            os.mkdir(path)

        if (
            os.path.exists(file_prop)
            and os.path.exists(file_data)
            and os.path.exists(file_level)
            and os.path.exists(file_rank_idx)
            and os.path.exists(file_rank_all)
            and use_cache
        ):
            print("Mapping Load from disk cache: " + file)
            slhwm = np.load(file_prop)
            self.ms_map = slhwm[:, 4:]
            self.ms_data = np.load(file_data)
            self.ms_start = slhwm[:, 0]
            self.ms_len = slhwm[:, 1]
            self.ms_level = np.load(file_level)
            self.ms_rank_idx = np.load(file_rank_idx)
            self.ms_rank_all = np.load(file_rank_all)
            self.ms_height = slhwm[:, 2]
            self.ms_weight = slhwm[:, 3]
            print("Mapping Load end, elapsed:" + str(time.time() - start_time) + "s")
        else:
            print("Mapping Disk cache miss, build new one. size:" + str(size))
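            # Build a fresh expansion table: row i lists the sub-items that meaning i
            # expands into, with -1 marking unused child slots. The rows below
            # vocab_size are pinned below to single leaf tokens, so every expansion
            # bottoms out in vocabulary tokens.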
size:" + str(size)) map = np.empty((size, max_subitem), dtype=np.int32) index = np.arange(0, size) map = np.random.random((size, max_subitem)) mask_zero = map.copy() mask_zero[:, 0:min_subitem] = 0.0 mask_zero.sort(axis=1) thre = np.random.random((size)).reshape(-1, 1).repeat(max_subitem, axis=1) mask_zero = mask_zero > thre item_sum = map.sum(axis=1) scale = (index / item_sum).reshape(-1, 1).repeat(max_subitem, axis=1) map = (map * scale).astype(np.int32) map[mask_zero] = -1 map[:vocab_size, 0] = np.arange(0, vocab_size) map[:vocab_size, 1:] = -1 ms_level = [] # meaning level, vocab's level is 0 ms_rank_idx = [] # meaning index of all level ms_rank_all = [] # meaning all of all level ms_start = np.zeros((size), dtype=np.int32) # meaning sequence start ms_end = np.zeros((size), dtype=np.int32) # meaning sequence end ms_len = np.zeros((size), dtype=np.int32) # meaning sequence len ms_height = np.zeros((size), dtype=np.int32) # meaning tree height ms_weight = np.zeros((size), dtype=np.int32) # meaning tree weight ms_data = np.zeros((datastep), dtype=np.int32) # meaning sequence ms_level = np.zeros((datastep), dtype=np.uint32) # meaning level, vocab's level is 0 ms_rank_idx = np.zeros((datastep), dtype=np.uint32) # meaning index of all level ms_rank_all = np.zeros((datastep), dtype=np.uint32) # meaning all of all level index = 0 for i in range(self.vocab_size): ms_data[i] = i ms_start[i] = index ms_end[i] = index + 1 ms_len[i] = 1 ms_level[i] = 0 ms_rank_idx[i] = 0xFFFFFFF ms_rank_all[i] = 0xFFFFFFF ms_height[i] = 0 ms_weight[i] = 1 index = index + 1 for i in range(self.vocab_size, size): m = map[i] # 当前meaning的拆分的分支 m = m[m >= 0] # donot cut off the map such as [0] m_len = len(m) # 当前meaning的拆分的分支个数 m_list = m.tolist() assert m_list, "map list can not be empty list" # 获取每个子meaning的start和end,并且生成序列组合成当前meaning完整的叶index(= 0].tolist(): if m >= self.vocab_size: pn = NodeTree(str(m), parent) get_tree_node(ms_map, m, vocab_size, pn, seqlist) else: pn = NodeTree("<" + str(m) + ">", parent) seqlist.append(pn) root = NodeTree(str(meaning)) seqlist = [] get_tree_node(self.ms_map, meaning, self.vocab_size, root, seqlist) root.seq_node = seqlist return root # 返回每个token相对于上一个token的level变化 # 返回两个list,分别表示 common -> current -> common 两个变化的level距离 def get_level_change(self, meaning): def level_change(ms_map, meaning, current_to_common, common_to_current): ms = ms_map[meaning] for m in ms[ms >= 0].tolist(): if m >= self.vocab_size: common_to_current[-1] = common_to_current[-1] + 1 level_change(ms_map, m, current_to_common, common_to_current) else: current_to_common.append(0) common_to_current.append(0) current_to_common[-2] = current_to_common[-2] + 1 common_to_current = [] common_to_current.append(1) current_to_common = [] current_to_common.append(0) level_change(self.ms_map, meaning, current_to_common, common_to_current) current_to_common = current_to_common[:-1] common_to_current = common_to_current[:-1] return current_to_common, common_to_current # 根据meaning的层级结构范围一个二位的数组,表示所有token跟前面token是否有关系 def get_relation_table(self, meaning): current_to_common, common_to_current = self.get_level_change(meaning) width = len(current_to_common) relation = np.zeros((width, width), dtype=int) relation[0, 0] = 1 for i in range(1, width, 1): if i == width - 2: print(1) ori = current_to_common[i] - common_to_current[i] start_index = width for s in range(i - 1, -1, -1): if ori < 0: break ori = ori - common_to_current[s] + current_to_common[s] start_index = s relation[i, start_index : i + 1] = 1.0 return relation def 
        for i in range(1, width, 1):
            ori = current_to_common[i] - common_to_current[i]
            start_index = width
            for s in range(i - 1, -1, -1):
                if ori < 0:
                    break
                ori = ori - common_to_current[s] + current_to_common[s]
                start_index = s
            relation[i, start_index : i + 1] = 1.0
        return relation

    def max_length(self):
        return max(self.ms_len)


class MeaningDataset(Dataset):
    def __init__(
        self,
        start,
        end,
        vocab_size,
        size=None,
        max_subitem=10,
        min_subitem=1,
        min_seq_len=2,
        seed=42,
        use_cache=True,
    ):
        np.random.seed(seed)
        self.start = start
        self.end = end
        self.vocab_size = vocab_size
        self.max_subitem = max_subitem
        self.min_subitem = min_subitem
        self.use_cache = use_cache
        self.min_seq_len = min_seq_len
        print("Build MeaningDataset from MeaningMap.")
        self.val_mask_level = None
        self.val_mask_idx = None
        self.seq = []
        self.level = []
        self.rank_idx = []
        self.rank_all = []
        self.seq_meaning = []
        map = self.get_meaning_map()
        self.m_height = map.ms_height
        self.m_weight = map.ms_weight
        if size:
            meanings = np.random.randint(start, end, size=(size))
        else:
            meanings = np.arange(start, end)

        seq_len = []
        for m in meanings:
            d, l, i, a = map.get_sequence(m)
            if len(d) >= min_seq_len:
                self.seq.append(d)
                self.level.append(l)
                self.seq_meaning.append(m)
                seq_len.append(len(d))
                # dm fills every nibble above the l valid rank nibbles with 0xF,
                # so unused levels never match a mask query.
                dm = np.ones(i.shape, dtype=np.uint32)
                dm = ((dm * 0xFFFFFFFF) << (l * 4)).astype(np.uint32)
                shift = (8 - l) * 4
                # Reverse the eight 4-bit nibbles of the packed rank index, then
                # right-align the l valid nibbles and mark the rest via dm.
                rank_idx = (i & 0xF) << 28
                rank_idx = rank_idx + ((i & 0xF0) << 20)
                rank_idx = rank_idx + ((i & 0xF00) << 12)
                rank_idx = rank_idx + ((i & 0xF000) << 4)
                rank_idx = rank_idx + ((i & 0xF0000) >> 4)
                rank_idx = rank_idx + ((i & 0xF00000) >> 12)
                rank_idx = rank_idx + ((i & 0xF000000) >> 20)
                rank_idx = rank_idx + ((i & 0xF0000000) >> 28)
                rank_idx = ((rank_idx >> shift) + dm).astype(np.uint32)
                # Same nibble reversal for the per-level sibling counts.
                rank_all = (a & 0xF) << 28
                rank_all = rank_all + ((a & 0xF0) << 20)
                rank_all = rank_all + ((a & 0xF00) << 12)
                rank_all = rank_all + ((a & 0xF000) << 4)
                rank_all = rank_all + ((a & 0xF0000) >> 4)
                rank_all = rank_all + ((a & 0xF00000) >> 12)
                rank_all = rank_all + ((a & 0xF000000) >> 20)
                rank_all = rank_all + ((a & 0xF0000000) >> 28)
                rank_all = ((rank_all >> shift) + dm).astype(np.uint32)
                self.rank_idx.append(rank_idx)
                self.rank_all.append(rank_all)

        unique, counts = np.unique(seq_len, return_counts=True)
        print("Build MeaningDataset end.")
        print("----------------------------------------------------------------")
        print("MeaningDataset start:" + str(start) + " end:" + str(end) + " space:" + str(end - start))
        print("MeaningDataset size:" + str(len(seq_len)))
        print("MeaningDataset max sequence length:" + str(max(unique)))
        print("MeaningDataset most popular sequence length:" + str(unique[np.argmax(counts)]))
        print("----------------------------------------------------------------")

    def __len__(self):
        return len(self.seq)

    def len(self):
        return len(self.seq)

    def get_meaning_map(self):
        return MeaningMap(self.end, self.vocab_size, self.max_subitem, self.min_subitem, self.use_cache)

    def set_mask(self, level=None, idx=None):
        if level is not None and idx is not None:
            assert isinstance(level, list), "mask level must be list"
            assert isinstance(idx, list), "mask index must be list"
            assert len(level) > 0, "len must > 0"
            assert len(level) == len(idx), "mask level and mask index must be same length"
        self.val_mask_level = level
        self.val_mask_idx = idx

    def __getitem__(self, idx):
        return self.get_batch([idx])

    def get_batch(self, idx_list):  # all indexed sequences must share one length
        data = [self.seq[i] for i in idx_list]
        output = {}
        data = torch.tensor(np.stack(data, axis=0)).long()
        output["input_ids"] = data
        output["labels"] = data.clone()
        output["token_type_ids"] = torch.zeros(data.shape)
        output["val_mask"] = self.get_seq_mask_tensor(idx_list)
        output["meaning"] = [self.seq_meaning[i] for i in idx_list]
        return output
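    # Usage sketch (illustrative only; the range and index values are hypothetical):
    #   ds = MeaningDataset(100000, 101000, vocab_size=1024)
    #   ds.set_mask([1], [-1])        # match tokens that are the last sibling at mask level 1
    #   batch = ds.get_batch([3, 7])  # both indices must reference equal-length sequences
    #   batch["input_ids"], batch["val_mask"]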
    def get_token(self, idx):  # token sequence of one sample
        return self.seq[idx]

    def get_meaning(self, idx):
        return self.seq_meaning[idx]

    def copy(self, start, end):
        new = copy.deepcopy(self)
        new.seq = new.seq[start:end]
        new.level = new.level[start:end]
        new.rank_idx = new.rank_idx[start:end]
        new.rank_all = new.rank_all[start:end]
        new.seq_meaning = new.seq_meaning[start:end]
        new.val_mask_level = self.val_mask_level
        new.val_mask_idx = self.val_mask_idx
        return new

    def split(self, ratio):
        l = self.len()
        middle = int(l * ratio)
        return self.copy(0, middle), self.copy(middle, l)

    def get_seq_mask(self, idx, level, index):
        # assert index < 15, "index must < 15"
        # assert level < 8, "level must < 8"
        rank_idx = (self.rank_idx[idx] >> (4 * level)).astype(np.int32) & 0xF
        rank_all = (self.rank_all[idx] >> (4 * level)).astype(np.int32) & 0xF
        # a negative index counts back from the sibling count, e.g. -1 == last sibling
        return rank_idx == (rank_all + index if index < 0 else index)

    def get_seq_mask_tensor(self, idx_list):
        if self.val_mask_level is not None and self.val_mask_idx is not None:
            mask = torch.tensor(
                np.stack(
                    [self.get_seq_mask(idx, self.val_mask_level[0], self.val_mask_idx[0]) for idx in idx_list], axis=0
                )
            )
            for i, l in enumerate(self.val_mask_level[1:]):
                mask = mask & torch.tensor(
                    np.stack([self.get_seq_mask(idx, l, self.val_mask_idx[i + 1]) for idx in idx_list], axis=0)
                )
            return mask
        else:
            return None


class BatchGroupMeaningDataloader(Dataset):
    def __init__(self, meaning_dataset: MeaningDataset, batch_size, shuffle=True, drop_last=True):
        self.meaning_dataset = meaning_dataset
        self.batch_size = batch_size
        self.drop_last = drop_last
        seq_len = np.array([len(s) for s in meaning_dataset.seq])
        unique, counts = np.unique(seq_len, return_counts=True)
        gl = {}
        for u in unique:
            gl[u] = np.where(seq_len == u)[0]
        lens = list(gl.keys())
        gs = {}
        for k in gl.keys():
            sl = gl[k].copy()
            if shuffle:
                np.random.shuffle(sl)
            gs[k] = sl
        index = np.zeros((0, batch_size), dtype=np.int64)
        for l in lens:
            batch = len(gs[l]) // batch_size
            new = gs[l][0 : batch * batch_size].reshape(batch, batch_size)
            index = np.concatenate((index, new), axis=0)
        if shuffle:
            index_shuffle = np.arange(0, index.shape[0])
            np.random.shuffle(index_shuffle)
            index = index[index_shuffle]
        self.indexBatch = index
        print("Dataloader batch size:" + str(batch_size) + " count:" + str(len(index)))
        print("Dataloader total:" + str(len(seq_len)) + " drop:" + str(len(seq_len) - len(index) * batch_size))

    def __len__(self):
        return len(self.indexBatch)

    def __getitem__(self, idx):
        return self.meaning_dataset.get_batch(self.indexBatch[idx])

    @staticmethod
    def detection_collate(batch):
        return batch[0]

    def dataloader(self, num_workers=1):
        return DataLoader(
            self, batch_size=1, num_workers=num_workers, collate_fn=BatchGroupMeaningDataloader.detection_collate
        )
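# Design note: batching within same-length groups avoids padding entirely, since
# every batch stacks equal-length sequences; the cost is that each group's
# remainder is dropped (reported as "drop" in the constructor).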
if __name__ == "__main__":

    md = MeaningDataset(100000, 115200, vocab_size=1024, size=1024, use_cache=False)
    md.set_mask([1], [-1])
    train, val = md.split(0.95)
    item = md[920]
    print(md.print_tree(920))
    print(md.rank_idx[920])
    print(md.rank_all[920])
    mask = md.get_seq_mask(920, 0, -1)
    print(mask)
    mask = md.get_seq_mask(920, 1, 0)
    print(mask)
    mask = md.get_seq_mask(920, 1, -1)
    print(mask)
    mask = md.get_seq_mask(920, 1, 1)
    print(mask)
    expected = np.array(
        [
            False, False, False, False, False, False, False, False, False, False,
            False, True, True, True, True, True, True, True, True, True,
            False, False, False, False, False, False, True, False, False, False,
            False, False, False, False, True, True, True, True, False, False,
            False, False, True, True, True, True, True, True, True, True,
            True, False, False, False, False, False, False,
        ]
    )
    assert all(np.equal(mask[0:57], expected)), "mask does not match the expected pattern"

    md = MeaningDataset(100000, 115200, vocab_size=1024, size=1024, min_subitem=2, use_cache=False)
    md.set_mask([0, 1], [0, -1])
    dl = BatchGroupMeaningDataloader(md, 1)
    length = len(dl)
    it = iter(dl)
    ne1 = next(it)
    tree = ne1["tree"]
    mask = ne1["val_mask"].cpu().numpy()
    t = MeaningMap.print_tree(tree)
    print(t)
    m, l = MeaningMap.get_tree_indexed_str(tree, mask, "")
    print(m)
    ne2 = next(it)
    ne3 = next(it)
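# Running this file directly exercises the masking and batching paths end to end;
# with use_cache=False both datasets rebuild their MeaningMap from scratch, so the
# self-test takes noticeably longer than a cached run.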