Update meaning dataset definition.

Colin 2024-03-26 11:32:02 +08:00
parent e29c0b9a41
commit b0ca4dc35d
2 changed files with 12 additions and 6 deletions

.gitignore

@@ -6,3 +6,7 @@ temp
 lightning_logs
 checkpoints
+build
+log
+logs
+data


@@ -17,11 +17,15 @@ class MeaningMap: # 16777216 1048576 8192
         self.vocab_size = vocab_size
         self.max_subitem = max_subitem
+        path = "./data/"
         file = "structured_language_" + str(size) + "_" + str(vocab_size) + "_" + str(max_subitem)
+        file = path + file
         file_start = file + "_start" + ".npy"
         file_len = file + "_len" + ".npy"
         file_data = file + "_data" + ".npy"
+        if not os.path.exists(path):
+            os.mkdir(path)
         if os.path.exists(file_start) and os.path.exists(file_len) and os.path.exists(file_data):
             print("Load from disk cache: " + file)
             self.ms_data = np.load(file_data)
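The additions in this hunk route the cached .npy files into a ./data/ directory and make sure that directory exists before the cache is checked. A minimal sketch of the resulting path handling, with made-up constructor arguments (the size/vocab_size/max_subitem values below are illustrative, not taken from the repository):

import os

size, vocab_size, max_subitem = 1048576, 4096, 10           # illustrative values only
path = "./data/"
file = "structured_language_" + str(size) + "_" + str(vocab_size) + "_" + str(max_subitem)
file = path + file                                           # e.g. ./data/structured_language_1048576_4096_10
file_start = file + "_start" + ".npy"
file_len = file + "_len" + ".npy"
file_data = file + "_data" + ".npy"

if not os.path.exists(path):                                 # create the cache directory up front so later
    os.mkdir(path)                                           # np.save / np.load calls see an existing folder

This lines up with adding data to .gitignore above: the cache directory is disposable and is presumably regenerated whenever the three .npy files are missing.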
@@ -133,12 +137,10 @@ class MeaningDataset(Dataset):
         output["token_type_ids"] = torch.zeros(data.shape)
         return output
 
-    def GetBatch(self, index_list):
-        data = []
-        for i in index_list:
-            data.append(self.data[i])
+    def GetBatch(self, index_list):  # must equal sequence length
+        data = [self.data[i] for i in index_list]
         output = {}
-        data = torch.tensor(data).long()
+        data = torch.tensor(np.stack(data, axis=0)).long()
         output["input_ids"] = data
         output["labels"] = data.clone()
         output["token_type_ids"] = torch.zeros(data.shape)
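The rewritten GetBatch gathers the requested rows with a list comprehension and stacks them with np.stack before converting to a tensor; np.stack raises a ValueError if the rows differ in length, which is what the "must equal sequence length" comment refers to. A rough, self-contained sketch of the same stacking behaviour (data_store is a hypothetical stand-in for self.data):

import numpy as np
import torch

# Hypothetical stand-in for self.data: equal-length token sequences.
data_store = [np.array([1, 2, 3, 4]), np.array([5, 6, 7, 8]), np.array([9, 10, 11, 12])]

index_list = [0, 2]                                          # indices requested by the caller
data = [data_store[i] for i in index_list]                   # same gather as the new GetBatch
data = torch.tensor(np.stack(data, axis=0)).long()           # fails if sequence lengths differ

output = {
    "input_ids": data,
    "labels": data.clone(),
    "token_type_ids": torch.zeros(data.shape),
}
print(output["input_ids"].shape)                             # torch.Size([2, 4])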