diff --git a/wit/meaning_dataset.py b/wit/meaning_dataset.py
index 9d52122..0fbf344 100644
--- a/wit/meaning_dataset.py
+++ b/wit/meaning_dataset.py
@@ -10,7 +10,7 @@ import numpy as np
 from torch.utils.data import BatchSampler


-class MeaningMap:  # 16777216 1048576 8192
+class MeaningMap:
     def __init__(self, size=1048576, vocab_size=4096, max_subitem=10):
         self.size = size
@@ -20,14 +20,21 @@ class MeaningMap:  # 16777216 1048576 8192
         path = "./data/"
         file = "structured_language_" + str(size) + "_" + str(vocab_size) + "_" + str(max_subitem)
         file = path + file
+        file_map = file + "_map" + ".npy"
         file_start = file + "_start" + ".npy"
         file_len = file + "_len" + ".npy"
         file_data = file + "_data" + ".npy"

         if not os.path.exists(path):
             os.mkdir(path)
-        if os.path.exists(file_start) and os.path.exists(file_len) and os.path.exists(file_data):
+        if (
+            os.path.exists(file_start)
+            and os.path.exists(file_len)
+            and os.path.exists(file_data)
+            and os.path.exists(file_map)
+        ):
             print("Load from disk cache: " + file)
+            self.ms_map = np.load(file_map)
             self.ms_data = np.load(file_data)
             self.ms_start = np.load(file_start)
             self.ms_len = np.load(file_len)
@@ -77,21 +84,30 @@ class MeaningMap:  # 16777216 1048576 8192
                 index = index + len(ma)

             ms_data = list(chain(*ms))
+            np.save(file_map, np.array(mm).astype(np.int32))
             np.save(file_data, np.array(ms_data).astype(np.int32))
             np.save(file_start, np.array(ms_start).astype(np.int32))
             np.save(file_len, np.array(ms_len).astype(np.int32))
+            self.ms_map = mm
             self.ms_data = ms_data
             self.ms_start = ms_start
             self.ms_len = ms_len
             print("Disk cache build end.")

-    def GetSequence(self, meaning):
+    def get_sequence(self, meaning):
         start = self.ms_start[meaning]
         len = self.ms_len[meaning]
         return self.ms_data[start : start + len]

-    def MaxLength(self):
+    def get_mapping(self, meaning):
+        mapping = {}
+        ms = self.ms_map[meaning]
+        for m in ms[ms > 0].tolist():
+            mapping[m] = self.get_mapping(m) if m >= self.vocab_size else m
+        return mapping
+
+    def max_length(self):
         return max(self.ms_len)
@@ -108,19 +124,23 @@ class MeaningDataset(Dataset):
         seed=42,
         data=None,
         length=None,
+        mapping=None,
     ):
-        if data != None and length != None:
+        if data != None and length != None and mapping != None:
             self.data = data
             self.length = length
+            self.mapping = mapping
             return
         np.random.seed(seed)
         mm = MeaningMap(size=end, vocab_size=vocab_size, max_subitem=max_subitem)  # 1048576
+        self.mapping = []
         self.data = []
         self.length = []
         meanings = np.random.randint(start, end, size=(size))
         for m in meanings:
-            sq = mm.GetSequence(m)
+            sq = mm.get_sequence(m)
             if len(sq) >= min_seq_len:
+                self.mapping.append(mm.get_mapping(m))
                 self.data.append(sq)
                 self.length.append(len(sq))
@@ -146,7 +166,7 @@ class MeaningDataset(Dataset):
         output["token_type_ids"] = torch.zeros(data.shape)
         return output

-    def GetBatch(self, index_list):  # must equal sequence length
+    def get_batch(self, index_list):  # must equal sequence length
         data = [self.data[i] for i in index_list]
         output = {}
         data = torch.tensor(np.stack(data, axis=0)).long()
@@ -155,13 +175,17 @@ class MeaningDataset(Dataset):
         output["token_type_ids"] = torch.zeros(data.shape)
         return output

-    def Split(self, ratio):
+    def get_mapping_batch(self, index_list):
+        return [self.mapping[i] for i in index_list]
+
+    def split(self, ratio):
         l = len(self.data)
         middle = int(l * ratio)
         d_shuffle = self.data.copy()
         l_shuffle = self.length.copy()
-        md1 = MeaningDataset(data=d_shuffle[:middle], length=l_shuffle[:middle])
-        md2 = MeaningDataset(data=d_shuffle[middle:], length=l_shuffle[middle:])
+        m_shuffle = self.mapping.copy()
+        md1 = MeaningDataset(data=d_shuffle[:middle], length=l_shuffle[:middle], mapping=m_shuffle[:middle])
+        md2 = MeaningDataset(data=d_shuffle[middle:], length=l_shuffle[middle:], mapping=m_shuffle[middle:])
         return md1, md2
@@ -195,6 +219,7 @@ class BatchGroupMeaningDataloader(Dataset):
             batch = len(gs[l]) // batch_size
             new = gs[l][0 : batch * batch_size].reshape(batch, batch_size)
             index = np.concatenate((index, new), axis=0)
+
         if shuffle:
             index_shuffle = np.arange(0, index.shape[0])
             np.random.shuffle(index_shuffle)
@@ -207,22 +232,26 @@ class BatchGroupMeaningDataloader(Dataset):
         return len(self.indexBatch)

     def __getitem__(self, idx):
-        # print("get idx" + str(idx))
-        return self.dataset.GetBatch(self.indexBatch[idx])
+        return self.dataset.get_batch(self.indexBatch[idx])
+
+    def mapping(self, idx):
+        return self.dataset.get_mapping_batch(self.indexBatch[idx])


 if __name__ == "__main__":
-    md = MeaningDataset(4096, 8100, size=1024)
-    train, val = md.Split(0.95)
+    md = MeaningDataset(1024, 115200, vocab_size=1024, size=1024)
+    train, val = md.split(0.95)

-    dl = BatchGroupMeaningDataloader(train, 32)
+    dl = BatchGroupMeaningDataloader(train, 2)
     length = len(dl)
     it = iter(dl)
     ne1 = next(it)
     ne2 = next(it)
     ne3 = next(it)
+    map = dl.mapping(0)
+
     dl = DataLoader(
         train,
         num_workers=1,
diff --git a/wit/train.py b/wit/train.py
index 955c13d..a59b378 100644
--- a/wit/train.py
+++ b/wit/train.py
@@ -54,7 +54,7 @@ if __name__ == "__main__":
     end = start * level_ratio
     size = int(vocab_size * (level_ratio**dataset_level))
     raw_dataset = MeaningDataset(start, end, size, vocab_size, level_ratio)
-    train_dataset, val_dataset = raw_dataset.Split(0.9)
+    train_dataset, val_dataset = raw_dataset.split(0.9)
     train_dataloader = BatchGroupMeaningDataloader(train_dataset, train_batch_size)
     val_dataloader = BatchGroupMeaningDataloader(val_dataset, val_batch_size)
     # it = iter(train_dataloader)
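
Note on the new get_mapping API: the recursion bottoms out at vocabulary-level ids (m < vocab_size), so each entry appended to self.mapping is a nested dict tracing a meaning down to its leaf tokens. Below is a minimal standalone sketch of that traversal; toy_map, its contents, and vocab_size = 4 are invented here purely for illustration and are not part of the patch.

import numpy as np

# Hypothetical sub-item table: row i lists the sub-items of meaning i,
# with 0 marking unused slots. Ids >= vocab_size expand recursively;
# smaller ids are leaf tokens.
toy_map = np.array(
    [
        [0, 0, 0],  # ids 0..3 are vocabulary tokens, never expanded
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [1, 2, 0],  # meaning 4 -> tokens 1, 2
        [4, 3, 0],  # meaning 5 -> meaning 4, then token 3
    ]
)
vocab_size = 4

def get_mapping(meaning):
    # Same shape as MeaningMap.get_mapping: skip padding zeros, recurse
    # into composite meanings, keep vocabulary ids as leaves.
    mapping = {}
    ms = toy_map[meaning]
    for m in ms[ms > 0].tolist():
        mapping[m] = get_mapping(m) if m >= vocab_size else m
    return mapping

print(get_mapping(5))  # {4: {1: 1, 2: 2}, 3: 3}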