import math
import os
from itertools import chain

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset


class MeaningMap:  # 16777216 1048576 8192
    """Hierarchical map from meaning ids to token sequences, cached on disk."""

    def __init__(self, size=1048576, vocab_size=4096, max_subitem=10):
        self.size = size
        self.vocab_size = vocab_size
        self.max_subitem = max_subitem

        # Cache files are keyed by the map parameters.
        file = "structured_language_" + str(size) + "_" + str(vocab_size) + "_" + str(max_subitem)
        file_start = file + "_start.npy"
        file_len = file + "_len.npy"
        file_data = file + "_data.npy"

        if os.path.exists(file_start) and os.path.exists(file_len) and os.path.exists(file_data):
            print("Load from disk cache: " + file)
            self.ms_data = np.load(file_data)
            self.ms_start = np.load(file_start)
            self.ms_len = np.load(file_len)
            return

        print("Disk cache miss, build new one.")
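        # Build the meaning hierarchy: mm[i] lists the sub-meanings that meaning i
        # expands into. Meanings below vocab_size are leaf tokens that map to
        # themselves; every larger meaning is composed of smaller meanings, so each
        # meaning eventually unfolds into a sequence of leaf tokens.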
        # Number of hierarchy levels and the meaning-id range of each level.
        total_level = int(math.log(size / vocab_size, max_subitem))

        start = [0]
        end = [vocab_size]
        shift = vocab_size
        for i in range(total_level):
            shift = end[-1]
            start.append(end[-1])
            end.append(shift * self.max_subitem)
        start.append(end[-1])
        end.append(size)

        index = np.arange(0, size)
        mm = np.random.random((size, max_subitem))

        # Randomly mark a subset of each row's entries to be zeroed out, so
        # meanings get a variable number of sub-items (the first slot is kept).
        mask_zero = mm.copy()
        mask_zero[:, 0] = 0.0
        mask_zero.sort(axis=1)
        thre = np.random.random(size).reshape(-1, 1).repeat(max_subitem, axis=1)
        mask_zero = mask_zero > thre

        # Scale each row so its sum equals the row's meaning id; after the integer
        # cast, sub-meanings end up smaller than the meaning they belong to.
        item_sum = mm.sum(axis=1)
        scale = (index / item_sum).reshape(-1, 1).repeat(max_subitem, axis=1)
        mm = mm * scale
        mm[mask_zero] = 0

        # The first vocab_size meanings are leaf tokens: they expand to themselves.
        mm[:vocab_size, 0] = np.arange(0, vocab_size)
        mm[:vocab_size, 1:] = 0
        mm = mm.astype(np.int32)
        ms = []  # meaning sequence (token expansion of each meaning)
        ms_start = []  # start offset of each meaning in the flat data array
        ms_len = []  # length of each meaning's sequence
        index = 0
        # Leaf meanings expand to a single token: themselves.
        for i in range(self.vocab_size):
            ms.append([i])
            ms_start.append(index)
            ms_len.append(1)
            index = index + 1

        # Composite meanings expand by concatenating the (already built) sequences
        # of their sub-meanings, which all have smaller ids.
        for i in range(self.vocab_size, size):
            m = mm[i]
            m = m[m > 0]
            ma = []
            for newm in m.tolist():
                ma = ma + ms[newm]
            ms.append(ma)
            ms_start.append(index)
            ms_len.append(len(ma))
            index = index + len(ma)

        # Flatten and store as arrays so the freshly built map matches the
        # format produced by the np.load path above.
        ms_data = np.array(list(chain(*ms))).astype(np.int32)
        ms_start = np.array(ms_start).astype(np.int32)
        ms_len = np.array(ms_len).astype(np.int32)
        np.save(file_data, ms_data)
        np.save(file_start, ms_start)
        np.save(file_len, ms_len)

        self.ms_data = ms_data
        self.ms_start = ms_start
        self.ms_len = ms_len
        print("Disk cache build end.")
    def GetSequence(self, meaning):
        start = self.ms_start[meaning]
        length = self.ms_len[meaning]
        return self.ms_data[start : start + length]

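
# MeaningDataset draws `size` meaning ids uniformly from [start, end) and
# materializes each one into its token sequence through a MeaningMap, so every
# sample is a variable-length sequence of base-vocabulary tokens.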
class MeaningDataset(Dataset):

    def __init__(self, start=131072, end=1048576, size=32768, vocab_size=4096, max_subitem=10, seed=42):
        self.seed = seed
        np.random.seed(seed)
        self.size = size
        self.mm = MeaningMap(size=end, vocab_size=vocab_size, max_subitem=max_subitem)  # 1048576
        self.data = []
        meanings = np.random.randint(start, end, size=size)
        for m in meanings:
            self.data.append(self.mm.GetSequence(m))

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        # Return a causal-LM style sample: labels are a copy of the inputs.
        output = {}
        data = torch.tensor(self.data[idx]).long()
        output["input_ids"] = data
        output["labels"] = data.clone()
        output["token_type_ids"] = torch.zeros(data.shape, dtype=torch.long)
        return output

if __name__ == "__main__":

    md = MeaningDataset(4096, 4100, size=32768)
    it = iter(md)
    for i in range(10):
        daf = next(it)["input_ids"].numpy().tolist()
        print(daf)
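
    # A minimal batching sketch under stated assumptions: the samples above have
    # different lengths, so stacking them with a torch DataLoader needs a collate
    # function that pads each field to the longest sequence in the batch. The
    # pad_collate helper and the pad id of 0 are illustrative choices, not part
    # of the dataset above.
    def pad_collate(batch, pad_id=0):
        max_len = max(item["input_ids"].size(0) for item in batch)
        out = {}
        for key in batch[0]:
            out[key] = torch.stack(
                [
                    torch.nn.functional.pad(item[key], (0, max_len - item[key].size(0)), value=pad_id)
                    for item in batch
                ]
            )
        return out

    dl = DataLoader(md, batch_size=4, shuffle=True, collate_fn=pad_collate)
    batch = next(iter(dl))
    print(batch["input_ids"].shape)  # (4, longest sequence length in the batch)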