From b0ca4dc35d48106e627caadb4efeb53086afc4c6 Mon Sep 17 00:00:00 2001 From: Colin Date: Tue, 26 Mar 2024 11:32:02 +0800 Subject: [PATCH] Update meaning dataset define. --- .gitignore | 6 +++++- wit/meaning_dataset.py | 12 +++++++----- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/.gitignore b/.gitignore index ecd3635..d4c498d 100644 --- a/.gitignore +++ b/.gitignore @@ -5,4 +5,8 @@ __pycache__ temp lightning_logs -checkpoints \ No newline at end of file +checkpoints +build +log +logs +data \ No newline at end of file diff --git a/wit/meaning_dataset.py b/wit/meaning_dataset.py index 5f8f032..1e89819 100644 --- a/wit/meaning_dataset.py +++ b/wit/meaning_dataset.py @@ -17,11 +17,15 @@ class MeaningMap: # 16777216 1048576 8192 self.vocab_size = vocab_size self.max_subitem = max_subitem + path = "./data/" file = "structured_language_" + str(size) + "_" + str(vocab_size) + "_" + str(max_subitem) + file = path + file file_start = file + "_start" + ".npy" file_len = file + "_len" + ".npy" file_data = file + "_data" + ".npy" + if not os.path.exists(path): + os.mkdir(path) if os.path.exists(file_start) and os.path.exists(file_len) and os.path.exists(file_data): print("Load from disk cache: " + file) self.ms_data = np.load(file_data) @@ -133,12 +137,10 @@ class MeaningDataset(Dataset): output["token_type_ids"] = torch.zeros(data.shape) return output - def GetBatch(self, index_list): - data = [] - for i in index_list: - data.append(self.data[i]) + def GetBatch(self, index_list): # must equal sequence length + data = [self.data[i] for i in index_list] output = {} - data = torch.tensor(data).long() + data = torch.tensor(np.stack(data, axis=0)).long() output["input_ids"] = data output["labels"] = data.clone() output["token_type_ids"] = torch.zeros(data.shape)