Fix dataset: add `min_subitem` parameter (with validation and cache-filename update) to MeaningMap/MeaningDataset, and fix off-by-one in `mask_idx` indexing within `get_seq_mask` loop.
This commit is contained in:
		
							parent
							
								
									6791264987
								
							
						
					
					
						commit
						9926b893a4
					
				|  | @ -12,13 +12,17 @@ import copy | |||
| 
 | ||||
| 
 | ||||
| class MeaningMap: | ||||
|     def __init__(self, size=1048576, vocab_size=4096, max_subitem=10, use_cache=True): | ||||
|     def __init__(self, size=1048576, vocab_size=4096, max_subitem=10, min_subitem=1, use_cache=True): | ||||
|         assert size > 0 and vocab_size > 0 and max_subitem > 0 and min_subitem > 0, "Invalid input" | ||||
|         assert min_subitem <= max_subitem, "Invalid input" | ||||
|         self.size = size | ||||
|         self.vocab_size = vocab_size | ||||
|         self.max_subitem = max_subitem | ||||
|         self.min_subitem = min_subitem | ||||
| 
 | ||||
|         path = "./data/" | ||||
|         file = "structured_language_" + str(size) + "_" + str(vocab_size) + "_" + str(max_subitem) | ||||
|         file = "structured_language_" + str(size) + "_" + str(vocab_size) | ||||
|         file += "_" + str(max_subitem) + "_" + str(min_subitem) | ||||
|         file = path + file + ".npz" | ||||
| 
 | ||||
|         if not os.path.exists(path): | ||||
|  | @ -47,7 +51,7 @@ class MeaningMap: | |||
|             map = np.random.random((size, max_subitem)) | ||||
| 
 | ||||
|             mask_zero = map.copy() | ||||
|             mask_zero[:, 0] = 0.0 | ||||
|             mask_zero[:, 0:min_subitem] = 0.0 | ||||
|             mask_zero.sort(axis=1) | ||||
|             thre = np.random.random((size)).reshape(-1, 1).repeat(max_subitem, axis=1) | ||||
|             mask_zero = mask_zero > thre | ||||
|  | @ -261,12 +265,13 @@ class MeaningDataset(Dataset): | |||
|         size, | ||||
|         vocab_size, | ||||
|         max_subitem=10, | ||||
|         min_subitem=1, | ||||
|         min_seq_len=2, | ||||
|         seed=42, | ||||
|         use_cache=True, | ||||
|     ): | ||||
|         np.random.seed(seed) | ||||
|         map = MeaningMap(size=end, vocab_size=vocab_size, max_subitem=max_subitem, use_cache=use_cache) | ||||
|         map = MeaningMap(end, vocab_size, max_subitem, min_subitem, use_cache=use_cache) | ||||
|         np.random.seed(seed) | ||||
| 
 | ||||
|         self.mask_level = None | ||||
|  | @ -381,7 +386,7 @@ class MeaningDataset(Dataset): | |||
|             ) | ||||
|             for i, l in enumerate(self.mask_level[1:]): | ||||
|                 mask = mask & torch.tensor( | ||||
|                     np.stack([self.get_seq_mask(idx, l, self.mask_idx[i]) for idx in idx_list], axis=0) | ||||
|                     np.stack([self.get_seq_mask(idx, l, self.mask_idx[i + 1]) for idx in idx_list], axis=0) | ||||
|                 ) | ||||
|             return mask | ||||
|         else: | ||||
|  | @ -529,7 +534,9 @@ if __name__ == "__main__": | |||
|     ), "False" | ||||
|     freq = md.token_frequency() | ||||
| 
 | ||||
|     dl = BatchGroupMeaningDataloader(val, 1) | ||||
|     md = MeaningDataset(100000, 115200, vocab_size=1024, size=1024, min_subitem=2, use_cache=False) | ||||
|     md.set_mask([0, 1], [0, -1]) | ||||
|     dl = BatchGroupMeaningDataloader(md, 1) | ||||
|     length = len(dl) | ||||
|     it = iter(dl) | ||||
|     ne1 = next(it) | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue