Update meaning dataset define.
This commit is contained in:
		
							parent
							
								
									e29c0b9a41
								
							
						
					
					
						commit
						b0ca4dc35d
					
				| 
						 | 
				
			
			@ -6,3 +6,7 @@ temp
 | 
			
		|||
lightning_logs
 | 
			
		||||
 | 
			
		||||
checkpoints
 | 
			
		||||
build
 | 
			
		||||
log
 | 
			
		||||
logs
 | 
			
		||||
data
 | 
			
		||||
| 
						 | 
				
			
			@ -17,11 +17,15 @@ class MeaningMap:  # 16777216 1048576 8192
 | 
			
		|||
        self.vocab_size = vocab_size
 | 
			
		||||
        self.max_subitem = max_subitem
 | 
			
		||||
 | 
			
		||||
        path = "./data/"
 | 
			
		||||
        file = "structured_language_" + str(size) + "_" + str(vocab_size) + "_" + str(max_subitem)
 | 
			
		||||
        file = path + file
 | 
			
		||||
        file_start = file + "_start" + ".npy"
 | 
			
		||||
        file_len = file + "_len" + ".npy"
 | 
			
		||||
        file_data = file + "_data" + ".npy"
 | 
			
		||||
 | 
			
		||||
        if not os.path.exists(path):
 | 
			
		||||
            os.mkdir(path)
 | 
			
		||||
        if os.path.exists(file_start) and os.path.exists(file_len) and os.path.exists(file_data):
 | 
			
		||||
            print("Load from disk cache: " + file)
 | 
			
		||||
            self.ms_data = np.load(file_data)
 | 
			
		||||
| 
						 | 
				
			
			@ -133,12 +137,10 @@ class MeaningDataset(Dataset):
 | 
			
		|||
        output["token_type_ids"] = torch.zeros(data.shape)
 | 
			
		||||
        return output
 | 
			
		||||
 | 
			
		||||
    def GetBatch(self, index_list):
 | 
			
		||||
        data = []
 | 
			
		||||
        for i in index_list:
 | 
			
		||||
            data.append(self.data[i])
 | 
			
		||||
    def GetBatch(self, index_list):  # must equal sequence length
 | 
			
		||||
        data = [self.data[i] for i in index_list]
 | 
			
		||||
        output = {}
 | 
			
		||||
        data = torch.tensor(data).long()
 | 
			
		||||
        data = torch.tensor(np.stack(data, axis=0)).long()
 | 
			
		||||
        output["input_ids"] = data
 | 
			
		||||
        output["labels"] = data.clone()
 | 
			
		||||
        output["token_type_ids"] = torch.zeros(data.shape)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue