Update meaning dataset definition.
This commit is contained in:
parent
e29c0b9a41
commit
b0ca4dc35d
|
@ -6,3 +6,7 @@ temp
|
|||
lightning_logs
|
||||
|
||||
checkpoints
|
||||
build
|
||||
log
|
||||
logs
|
||||
data
|
|
@ -17,11 +17,15 @@ class MeaningMap: # 16777216 1048576 8192
|
|||
self.vocab_size = vocab_size
|
||||
self.max_subitem = max_subitem
|
||||
|
||||
path = "./data/"
|
||||
file = "structured_language_" + str(size) + "_" + str(vocab_size) + "_" + str(max_subitem)
|
||||
file = path + file
|
||||
file_start = file + "_start" + ".npy"
|
||||
file_len = file + "_len" + ".npy"
|
||||
file_data = file + "_data" + ".npy"
|
||||
|
||||
if not os.path.exists(path):
|
||||
os.mkdir(path)
|
||||
if os.path.exists(file_start) and os.path.exists(file_len) and os.path.exists(file_data):
|
||||
print("Load from disk cache: " + file)
|
||||
self.ms_data = np.load(file_data)
|
||||
|
@ -133,12 +137,10 @@ class MeaningDataset(Dataset):
|
|||
output["token_type_ids"] = torch.zeros(data.shape)
|
||||
return output
|
||||
|
||||
def GetBatch(self, index_list):
|
||||
data = []
|
||||
for i in index_list:
|
||||
data.append(self.data[i])
|
||||
def GetBatch(self, index_list): # must equal sequence length
|
||||
data = [self.data[i] for i in index_list]
|
||||
output = {}
|
||||
data = torch.tensor(data).long()
|
||||
data = torch.tensor(np.stack(data, axis=0)).long()
|
||||
output["input_ids"] = data
|
||||
output["labels"] = data.clone()
|
||||
output["token_type_ids"] = torch.zeros(data.shape)
|
||||
|
|
Loading…
Reference in New Issue