Update meaning dataset define.
This commit is contained in:
parent
e29c0b9a41
commit
d6c78ecd68
|
@ -17,11 +17,15 @@ class MeaningMap: # 16777216 1048576 8192
|
||||||
self.vocab_size = vocab_size
|
self.vocab_size = vocab_size
|
||||||
self.max_subitem = max_subitem
|
self.max_subitem = max_subitem
|
||||||
|
|
||||||
|
path = "./data/"
|
||||||
file = "structured_language_" + str(size) + "_" + str(vocab_size) + "_" + str(max_subitem)
|
file = "structured_language_" + str(size) + "_" + str(vocab_size) + "_" + str(max_subitem)
|
||||||
|
file = path + file
|
||||||
file_start = file + "_start" + ".npy"
|
file_start = file + "_start" + ".npy"
|
||||||
file_len = file + "_len" + ".npy"
|
file_len = file + "_len" + ".npy"
|
||||||
file_data = file + "_data" + ".npy"
|
file_data = file + "_data" + ".npy"
|
||||||
|
|
||||||
|
if not os.path.exists(path):
|
||||||
|
os.mkdir(path)
|
||||||
if os.path.exists(file_start) and os.path.exists(file_len) and os.path.exists(file_data):
|
if os.path.exists(file_start) and os.path.exists(file_len) and os.path.exists(file_data):
|
||||||
print("Load from disk cache: " + file)
|
print("Load from disk cache: " + file)
|
||||||
self.ms_data = np.load(file_data)
|
self.ms_data = np.load(file_data)
|
||||||
|
@ -133,12 +137,10 @@ class MeaningDataset(Dataset):
|
||||||
output["token_type_ids"] = torch.zeros(data.shape)
|
output["token_type_ids"] = torch.zeros(data.shape)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def GetBatch(self, index_list):
|
def GetBatch(self, index_list): # must equal sequence length
|
||||||
data = []
|
data = [self.data[i] for i in index_list]
|
||||||
for i in index_list:
|
|
||||||
data.append(self.data[i])
|
|
||||||
output = {}
|
output = {}
|
||||||
data = torch.tensor(data).long()
|
data = torch.tensor(np.stack(data, axis=0)).long()
|
||||||
output["input_ids"] = data
|
output["input_ids"] = data
|
||||||
output["labels"] = data.clone()
|
output["labels"] = data.clone()
|
||||||
output["token_type_ids"] = torch.zeros(data.shape)
|
output["token_type_ids"] = torch.zeros(data.shape)
|
||||||
|
|
Loading…
Reference in New Issue