import os import torch, datasets import math, gc, time, random, copy from itertools import chain from typing import Dict, Tuple from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split import numpy as np from torch.utils.data import BatchSampler # import warnings # warnings.filterwarnings("ignore", ".*does not have many workers.*") class MeaningMap: def __init__(self, size=1048576, vocab_size=4096, max_subitem=10, min_subitem=1, use_cache=True): assert size > 0 and vocab_size > 0 and max_subitem > 0 and min_subitem > 0, "Invalid input" assert min_subitem <= max_subitem, "Invalid input" self.size = size self.vocab_size = vocab_size self.max_subitem = max_subitem self.min_subitem = min_subitem path = "./data/" file = "structured_language_" + str(size) + "_" + str(vocab_size) file += "_" + str(max_subitem) + "_" + str(min_subitem) file = path + file + ".npz" start_time = time.time() if not os.path.exists(path): os.mkdir(path) if os.path.exists(file) and use_cache: print("Load from disk cache: " + file) loaded = np.load(file) slhwm = loaded["slhwm"] dlra = loaded["dlra"] self.ms_map = slhwm[:, 4:] self.ms_data = dlra[:, 0] self.ms_start = slhwm[:, 0] self.ms_len = slhwm[:, 1] self.ms_level = dlra[:, 1] self.ms_rank_idx = dlra[:, 2].astype(np.uint32) self.ms_rank_all = dlra[:, 3].astype(np.uint32) self.ms_height = slhwm[:, 2] self.ms_weight = slhwm[:, 3] print("Load end, elapsed:" + str(time.time() - start_time) + "s") else: print("Disk cache miss, build new one.") map = np.empty((size, max_subitem), dtype=np.int32) index = np.arange(0, size) map = np.random.random((size, max_subitem)) mask_zero = map.copy() mask_zero[:, 0:min_subitem] = 0.0 mask_zero.sort(axis=1) thre = np.random.random((size)).reshape(-1, 1).repeat(max_subitem, axis=1) mask_zero = mask_zero > thre item_sum = map.sum(axis=1) scale = (index / item_sum).reshape(-1, 1).repeat(max_subitem, axis=1) map = (map * scale).astype(np.int32) map[mask_zero] = -1 map[:vocab_size, 0] = np.arange(0, vocab_size) map[:vocab_size, 1:] = -1 ms_level = [] # meaning level, vocab's level is 0 ms_rank_idx = [] # meaning index of all level ms_rank_all = [] # meaning all of all level ms_start = np.zeros((size), dtype=np.int32) # meaning sequence start ms_end = np.zeros((size), dtype=np.int32) # meaning sequence end ms_len = np.zeros((size), dtype=np.int32) # meaning sequence len ms_height = np.zeros((size), dtype=np.int32) # meaning tree height ms_weight = np.zeros((size), dtype=np.int32) # meaning tree weight ms_data = np.zeros((268435456), dtype=np.int32) # meaning sequence ms_level = np.zeros((268435456), dtype=np.int32) # meaning level, vocab's level is 0 ms_rank_idx = np.zeros((268435456), dtype=np.uint32) # meaning index of all level ms_rank_all = np.zeros((268435456), dtype=np.uint32) # meaning all of all level index = 0 for i in range(self.vocab_size): ms_data[i] = i ms_start[i] = index ms_end[i] = index + 1 ms_len[i] = 1 ms_height[i] = 0 ms_weight[i] = 1 index = index + 1 for i in range(self.vocab_size, size): m = map[i] m = m[m >= 0] # donot cut off the map such as [0] m_len = len(m) m_list = m.tolist() assert m_list, "map list can not be empty list" ma = np.concatenate([ms_data[ms_start[newm] : ms_end[newm]] for newm in m_list]) len_ma = len(ma) end = index + len_ma if ms_data.size < end: ms_data = np.concatenate([ms_data, np.zeros((268435456), dtype=np.int32)]) ms_level = np.concatenate([ms_level, np.zeros((268435456), dtype=np.int32)]) ms_rank_idx = np.concatenate([ms_rank_idx, np.zeros((268435456), dtype=np.uint32)]) ms_rank_all = np.concatenate([ms_rank_all, np.zeros((268435456), dtype=np.uint32)]) ms_data[index:end] = ma ms_level[index:end] = np.concatenate([ms_level[ms_start[newm] : ms_end[newm]] + 1 for newm in m_list]) ms_rank_idx[index:end] = np.concatenate( [ ( [0xFFFFFFF0 + i] if newm < self.vocab_size else ms_rank_idx[ms_start[newm] : ms_end[newm]] * 16 + i ) for i, newm in enumerate(m_list) ] ).astype(np.uint32) ms_rank_all[index:end] = np.concatenate( [ ( [0xFFFFFFF0 + m_len] if newm < self.vocab_size else ms_rank_all[ms_start[newm] : ms_end[newm]] * 16 + m_len ) for i, newm in enumerate(m_list) ] ).astype(np.uint32) ms_start[i] = index ms_end[i] = end ms_len[i] = len_ma ms_height[i] = max([ms_height[sub_m] for sub_m in m_list]) + 1 ms_weight[i] = sum(ms_weight[sub_m] for sub_m in m_list) index = index + len_ma if i % 10000 == 0: print(i) print("Mapping end, elapsed:" + str(time.time() - start_time) + "s") d = np.ones(ms_rank_idx.shape, dtype=np.uint32) d = ((d * 0xFFFFFFFF) << (ms_level * 4)).astype(np.uint32) ms_rank_idx = ( ((ms_rank_idx & 0xF) << 28) + ((ms_rank_idx & 0xF0) << 20) + ((ms_rank_idx & 0xF00) << 12) + ((ms_rank_idx & 0xF000) << 4) + ((ms_rank_idx & 0xF0000) >> 4) + ((ms_rank_idx & 0xF00000) >> 12) + ((ms_rank_idx & 0xF000000) >> 20) + ((ms_rank_idx & 0xF0000000) >> 28) ) ms_rank_idx = ((ms_rank_idx >> ((8 - ms_level) * 4)) + d).astype(np.uint32) ms_rank_all = ( ((ms_rank_all & 0xF) << 28) + ((ms_rank_all & 0xF0) << 20) + ((ms_rank_all & 0xF00) << 12) + ((ms_rank_all & 0xF000) << 4) + ((ms_rank_all & 0xF0000) >> 4) + ((ms_rank_all & 0xF00000) >> 12) + ((ms_rank_all & 0xF000000) >> 20) + ((ms_rank_all & 0xF0000000) >> 28) ) ms_rank_all = ((ms_rank_all >> ((8 - ms_level) * 4)) + d).astype(np.uint32) ms_start = np.array(ms_start).astype(np.int32) ms_height = np.array(ms_height).astype(np.int32) ms_weight = np.array(ms_weight).astype(np.int32) ms_len = np.array(ms_len).astype(np.int32) slhwm = np.concatenate( ( ms_start.reshape((-1, 1)), ms_len.reshape((-1, 1)), ms_height.reshape((-1, 1)), ms_weight.reshape((-1, 1)), map, ), axis=1, ) dlra = np.stack((ms_data, ms_level, ms_rank_idx.astype(np.int32), ms_rank_all.astype(np.int32)), axis=1) np.savez(file, slhwm=slhwm, dlra=dlra) self.ms_data = ms_data # map[i]=ms_data[ms_start[i]:ms_start[i]+ms_len[i]] self.ms_level = ms_level self.ms_rank_idx = ms_rank_idx self.ms_rank_all = ms_rank_all self.ms_map = map # ms_map[i] = [sub(i),sub(i),sub(i),sub(i)...sub(i)] self.ms_start = ms_start self.ms_len = ms_len self.ms_height = ms_height self.ms_weight = ms_weight print("Disk cache build end, elapsed:" + str(time.time() - start_time) + "s") def get_sequence(self, meaning): # return sequence[meaning] start = self.ms_start[meaning] len = self.ms_len[meaning] return ( self.ms_data[start : start + len], self.ms_level[start : start + len], self.ms_rank_idx[start : start + len], self.ms_rank_all[start : start + len], ) def get_tree(self, meaning): # return meaning all sub items tree = {} ms = self.ms_map[meaning] for m in ms[ms > 0].tolist(): tree[m] = self.get_tree(m) if m >= self.vocab_size else m return tree def max_length(self): return max(self.ms_len) def get_tree_str(tree, prefix): if isinstance(tree, list): base = "" for t in tree: base += MeaningMap.get_tree_str(t, "") return base else: if isinstance(tree, dict): base = "" last_is_dict = None for key, value in tree.items(): new_prefix = (len(str(key)) + 2) * " " + prefix dict_string = MeaningMap.get_tree_str(value, new_prefix) if dict_string: base += "\n" + prefix + str(key) + ": " + dict_string last_is_dict = True else: base += "\n" + prefix + str(key) + " " if last_is_dict else str(key) + " " last_is_dict = False return base return None def get_tree_indexed_str(tree, data, prefix): if isinstance(tree, list): base = "" qlen = 0 for i, t in enumerate(tree): s, l = MeaningMap.get_tree_indexed_str(t, data[i], "") base += s qlen += l return (base, qlen) else: if isinstance(tree, dict): base = "" qlen = 0 last_is_dict = None for key, value in tree.items(): new_prefix = (len(str(key)) + 2) * " " + prefix dict_string, l = MeaningMap.get_tree_indexed_str(value, data[qlen:], new_prefix) if dict_string: base += "\n" + prefix + str(key) + ": " + dict_string last_is_dict = True else: base += "\n" + prefix + str(data[qlen]) + " " if last_is_dict else str(data[qlen]) + " " last_is_dict = False qlen += l return (base, qlen) return (None, 1) def token_frequency(tree, freq): if isinstance(tree, dict): for key, value in tree.items(): if key in freq: freq[key] = freq[key] + 1 else: freq[key] = 1 MeaningMap.token_frequency(value, freq) class MeaningDataset(Dataset): def __init__( self, start, end, size, vocab_size, max_subitem=10, min_subitem=1, min_seq_len=2, seed=42, use_cache=True, ): np.random.seed(seed) map = MeaningMap(end, vocab_size, max_subitem, min_subitem, use_cache=use_cache) np.random.seed(seed) self.mask_level = None self.mask_idx = None self.tree = [] self.seq = [] self.level = [] self.rank_idx = [] self.rank_all = [] self.seq_meaning = [] self.m_height = map.ms_height self.m_weight = map.ms_weight meanings = np.random.randint(start, end, size=(size)) seq_len = [] for m in meanings: d, l, i, a = map.get_sequence(m) if len(d) >= min_seq_len: self.tree.append({m: map.get_tree(m)}) self.seq.append(d) self.level.append(l) self.rank_idx.append(i) self.rank_all.append(a) self.seq_meaning.append(m) seq_len.append(len(d)) unique, counts = np.unique(seq_len, return_counts=True) print("----------------------------------------------------------------") print("MeaningDataset start:" + str(start) + " end:" + str(end) + " space:" + str(end - start)) print("MeaningDataset size:" + str(len(seq_len))) print("MeaningDataset max sequence length:" + str(max(unique))) print("MeaningDataset most popular sequence length:" + str(unique[np.argmax(counts)])) print("----------------------------------------------------------------") def __len__(self): return len(self.seq) def len(self): return len(self.seq) def set_mask(self, level=None, idx=None): if self.mask_level is not None and self.mask_idx is not None: assert len(self.mask_level) > 0, "len must > 0" assert len(self.mask_level) == len(self.mask_idx), "mask level and mask index must be same length" assert isinstance(self.mask_level, list), "mask level must be list" assert isinstance(self.mask_idx, list), "mask index must be list" self.mask_level = level self.mask_idx = idx def __getitem__(self, idx): return self.get_batch([idx]) def get_batch(self, idx_list): # must equal sequence length data = [self.seq[i] for i in idx_list] output = {} data = torch.tensor(np.stack(data, axis=0)).long() output["input_ids"] = data output["labels"] = data.clone() output["token_type_ids"] = torch.zeros(data.shape) output["tree"] = [self.tree[i] for i in idx_list] output["level"] = [self.level[i] for i in idx_list] output["mask"] = self.get_seq_mask_tensor(idx_list) return output def get_token(self, idx): # must equal sequence length return self.seq[idx] def get_tree(self, idx): return self.tree[idx] def print_tree(self, idx): tokens = self.seq[idx] tree = self.get_tree(idx) s = str(tokens) + "\n" s += MeaningMap.get_tree_str(tree, "") return s def copy(self, start, end): new = copy.deepcopy(self) new.tree = new.tree[start:end] new.seq = new.seq[start:end] new.level = new.level[start:end] new.rank_idx = new.rank_idx[start:end] new.rank_all = new.rank_all[start:end] new.seq_meaning = new.seq_meaning[start:end] new.mask_level = self.mask_level new.mask_idx = self.mask_idx return new def split(self, ratio): l = self.len() middle = int(l * ratio) return self.copy(0, middle), self.copy(middle, l) def token_frequency(self): freq = {} for t in self.tree: MeaningMap.token_frequency(t, freq) return freq def get_seq_mask(self, idx, level, index): assert index < 15, "index must < 15" assert level < 8, "level must < 8" rank_idx = (self.rank_idx[idx] >> (4 * level)).astype(np.int32) & 0xF rank_all = (self.rank_all[idx] >> (4 * level)).astype(np.int32) & 0xF return rank_idx == (rank_all + index if index < 0 else index) def get_seq_mask_tensor(self, idx_list): if self.mask_level is not None and self.mask_idx is not None: mask = torch.tensor( np.stack([self.get_seq_mask(idx, self.mask_level[0], self.mask_idx[0]) for idx in idx_list], axis=0) ) for i, l in enumerate(self.mask_level[1:]): mask = mask & torch.tensor( np.stack([self.get_seq_mask(idx, l, self.mask_idx[i + 1]) for idx in idx_list], axis=0) ) return mask else: return None class BatchGroupMeaningDataloader(Dataset): def __init__(self, dataset: MeaningDataset, batch_size, shuffle=True, drop_last=True): self.dataset = dataset self.batch_size = batch_size self.drop_last = drop_last seq_len = [len(s) for s in dataset.seq] unique, counts = np.unique(seq_len, return_counts=True) gl = {} for u in unique: gl[u] = np.where(seq_len == u)[0] lens = list(gl.keys()) gs = {} if shuffle: for k in gl.keys(): sl = gl[k].copy() np.random.shuffle(sl) gs[k] = sl else: for k in gl.keys(): sl = gl[k].copy() gs[k] = sl index = np.zeros((0, batch_size), dtype=np.int64) for l in lens: batch = len(gs[l]) // batch_size new = gs[l][0 : batch * batch_size].reshape(batch, batch_size) index = np.concatenate((index, new), axis=0) if shuffle: index_shuffle = np.arange(0, index.shape[0]) np.random.shuffle(index_shuffle) index = index[index_shuffle] self.indexBatch = index print("Dataloader batch size:" + str(batch_size) + " count:" + str(len(index))) print("Dataloader total:" + str(len(seq_len)) + " drop:" + str(len(seq_len) - len(index) * batch_size)) def __len__(self): return len(self.indexBatch) def __getitem__(self, idx): return self.dataset.get_batch(self.indexBatch[idx]) def get_tree(self, idx): return [self.dataset.get_tree(i) for i in self.indexBatch[idx]] def print_tree(self, idx): idx_list = self.indexBatch[idx] s = "--------------------------------------------------------\n" for i in idx_list: s += self.dataset.print_tree(i) s += "--------------------------------------------------------\n" return s def detection_collate(batch): return batch[0] def dataloader(self, num_workers=1): return DataLoader( self, batch_size=1, num_workers=num_workers, collate_fn=BatchGroupMeaningDataloader.detection_collate ) if __name__ == "__main__": md = MeaningDataset(100000, 115200, vocab_size=1024, size=1024, use_cache=False) md.set_mask([1], [-1]) train, val = md.split(0.95) fdaf = md.__getitem__(920) print(md.print_tree(920)) print(md.rank_idx[920]) print(md.rank_all[920]) mask = md.get_seq_mask(920, 0, -1) print(mask) mask = md.get_seq_mask(920, 1, 0) print(mask) mask = md.get_seq_mask(920, 1, -1) print(mask) mask = md.get_seq_mask(920, 1, 1) print(mask) assert all( np.equal( mask[0:57], np.array( [ False, False, False, False, False, False, False, False, False, False, False, True, True, True, True, True, True, True, True, True, False, False, False, False, False, False, True, False, False, False, False, False, False, False, True, True, True, True, False, False, False, False, True, True, True, True, True, True, True, True, True, False, False, False, False, False, False, ] ), ) ), "False" freq = md.token_frequency() md = MeaningDataset(100000, 115200, vocab_size=1024, size=1024, min_subitem=2, use_cache=False) md.set_mask([0, 1], [0, -1]) dl = BatchGroupMeaningDataloader(md, 1) length = len(dl) it = iter(dl) ne1 = next(it) tree = ne1["tree"] mask = ne1["mask"].cpu().numpy() t = MeaningMap.get_tree_str(tree, "") print(t) m, l = MeaningMap.get_tree_indexed_str(tree, mask, "") print(m) ne2 = next(it) ne3 = next(it) # map1 = dl.get_tree(0) # map2 = dl.get_tree(1) # print(dl.print_tree(0)) # dl = DataLoader( # train, # num_workers=1, # persistent_workers=True, # shuffle=False, # ) # it = iter(dl) # ne1 = next(it) # ne2 = next(it) # ne3 = next(it) # for i in range(10): # print(next(it)["input_ids"].numpy().tolist())