diff --git a/wit/lit_module.py b/wit/lit_module.py index 5e3852c..8bfeab0 100644 --- a/wit/lit_module.py +++ b/wit/lit_module.py @@ -63,12 +63,14 @@ class LitModule(pl.LightningModule): logits = logits.contiguous().view(-1, logits.size(-1)) labels = batch["labels"][..., 1:] labels = labels.contiguous().view(-1) - label_mask = batch["mask"][..., 1:] - label_mask = label_mask.contiguous().view(-1) - logits_m = logits[label_mask] - labels_m = labels[label_mask] - self.metric_accuracy.update(logits_m, labels_m) - self.metric_loss.update(loss) + if batch["mask"] != None: + label_mask = batch["mask"][..., 1:] + label_mask = label_mask.contiguous().view(-1) + logits = logits[label_mask] + labels = labels[label_mask] + if logits.numel() != 0 and labels.numel() != 0: + self.metric_accuracy.update(logits, labels) + self.metric_loss.update(loss) def on_validation_epoch_end(self) -> None: self.log("val_loss", self.metric_loss, rank_zero_only=True) diff --git a/wit/meaning_dataset.py b/wit/meaning_dataset.py index 7b49792..cd57f59 100644 --- a/wit/meaning_dataset.py +++ b/wit/meaning_dataset.py @@ -103,11 +103,6 @@ class MeaningMap: for i, newm in enumerate(m_list) ] ) - # ml = ml[ma > 0] # cut off the 0 token, such as [12,32,0,42,32] - # mr = mr[ma > 0] - # mrl = mrl[ma > 0] - # ma = ma[ma > 0] - ms_data.append(ma) ms_level.append(ml) ms_rank_idx.append(mr) @@ -199,20 +194,53 @@ class MeaningMap: return max(self.ms_len) def get_tree_str(tree, prefix): - if isinstance(tree, dict): + if isinstance(tree, list): base = "" - last_is_dict = None - for key, value in tree.items(): - new_prefix = (len(str(key)) + 2) * " " + prefix - dict_string = MeaningMap.get_tree_str(value, new_prefix) - if dict_string: - base += "\n" + prefix + str(key) + ": " + dict_string - last_is_dict = True - else: - base += "\n" + prefix + str(key) + " " if last_is_dict else str(key) + " " - last_is_dict = False + for t in tree: + base += MeaningMap.get_tree_str(t, "") return base - return None + else: + if isinstance(tree, dict): + base = "" + last_is_dict = None + for key, value in tree.items(): + new_prefix = (len(str(key)) + 2) * " " + prefix + dict_string = MeaningMap.get_tree_str(value, new_prefix) + if dict_string: + base += "\n" + prefix + str(key) + ": " + dict_string + last_is_dict = True + else: + base += "\n" + prefix + str(key) + " " if last_is_dict else str(key) + " " + last_is_dict = False + return base + return None + + def get_tree_indexed_str(tree, data, prefix): + if isinstance(tree, list): + base = "" + qlen = 0 + for i, t in enumerate(tree): + s, l = MeaningMap.get_tree_indexed_str(t, data[i], "") + base += s + qlen += l + return (base, qlen) + else: + if isinstance(tree, dict): + base = "" + qlen = 0 + last_is_dict = None + for key, value in tree.items(): + new_prefix = (len(str(key)) + 2) * " " + prefix + dict_string, l = MeaningMap.get_tree_indexed_str(value, data[qlen:], new_prefix) + if dict_string: + base += "\n" + prefix + str(key) + ": " + dict_string + last_is_dict = True + else: + base += "\n" + prefix + str(data[qlen]) + " " if last_is_dict else str(data[qlen]) + " " + last_is_dict = False + qlen += l + return (base, qlen) + return (None, 1) def token_frequency(tree, freq): if isinstance(tree, dict): @@ -280,22 +308,16 @@ class MeaningDataset(Dataset): return len(self.seq) def set_mask(self, level=None, idx=None): + if self.mask_level is not None and self.mask_idx is not None: + assert len(self.mask_level) > 0, "len must > 0" + assert len(self.mask_level) == len(self.mask_idx), "mask level and mask index must be same length" + assert isinstance(self.mask_level, list), "mask level must be list" + assert isinstance(self.mask_idx, list), "mask index must be list" self.mask_level = level self.mask_idx = idx def __getitem__(self, idx): - output = {} - data = torch.tensor(self.seq[idx]).long() - output["input_ids"] = data - output["labels"] = data.clone() - output["token_type_ids"] = torch.zeros(data.shape) - output["tree"] = self.tree[idx] - output["level"] = self.level[idx] - if self.mask_level is not None and self.mask_idx is not None: - output["mask"] = torch.tensor(self.get_seq_mask(idx, self.mask_level, self.mask_idx)) - else: - output["mask"] = torch.ones(data.shape, dtype=torch.long) - return output + return self.get_batch([idx]) def get_batch(self, idx_list): # must equal sequence length data = [self.seq[i] for i in idx_list] @@ -306,12 +328,7 @@ class MeaningDataset(Dataset): output["token_type_ids"] = torch.zeros(data.shape) output["tree"] = [self.tree[i] for i in idx_list] output["level"] = [self.level[i] for i in idx_list] - if self.mask_level is not None and self.mask_idx is not None: - output["mask"] = torch.tensor( - np.stack([self.get_seq_mask(i, self.mask_level, self.mask_idx) for i in idx_list], axis=0) - ) - else: - output["mask"] = torch.ones(data.shape, dtype=torch.long) + output["mask"] = self.get_seq_mask_tensor(idx_list) return output def get_token(self, idx): # must equal sequence length @@ -335,6 +352,8 @@ class MeaningDataset(Dataset): new.rank_idx = new.rank_idx[start:end] new.rank_all = new.rank_all[start:end] new.seq_meaning = new.seq_meaning[start:end] + new.mask_level = self.mask_level + new.mask_idx = self.mask_idx return new def split(self, ratio): @@ -355,6 +374,19 @@ class MeaningDataset(Dataset): rank_all = (self.rank_all[idx] >> (4 * level)).astype(np.int32) & 0xF return rank_idx == (rank_all + index if index < 0 else index) + def get_seq_mask_tensor(self, idx_list): + if self.mask_level is not None and self.mask_idx is not None: + mask = torch.tensor( + np.stack([self.get_seq_mask(idx, self.mask_level[0], self.mask_idx[0]) for idx in idx_list], axis=0) + ) + for i, l in enumerate(self.mask_level[1:]): + mask = mask & torch.tensor( + np.stack([self.get_seq_mask(idx, l, self.mask_idx[i]) for idx in idx_list], axis=0) + ) + return mask + else: + return None + class BatchGroupMeaningDataloader(Dataset): def __init__(self, dataset: MeaningDataset, batch_size, shuffle=True, drop_last=True): @@ -414,7 +446,8 @@ class BatchGroupMeaningDataloader(Dataset): if __name__ == "__main__": - md = MeaningDataset(100000, 115200, vocab_size=1024, size=1024, use_cache=False) + md = MeaningDataset(100000, 115200, vocab_size=1024, size=1024, use_cache=True) + md.set_mask([1], [-1]) train, val = md.split(0.95) fdaf = md.__getitem__(920) print(md.print_tree(920)) @@ -496,12 +529,18 @@ if __name__ == "__main__": ), "False" freq = md.token_frequency() - dl = BatchGroupMeaningDataloader(train, 2) - # length = len(dl) - # it = iter(dl) - # ne1 = next(it) - # ne2 = next(it) - # ne3 = next(it) + dl = BatchGroupMeaningDataloader(val, 1) + length = len(dl) + it = iter(dl) + ne1 = next(it) + tree = ne1["tree"] + mask = ne1["mask"].cpu().numpy() + t = MeaningMap.get_tree_str(tree, "") + print(t) + m, l = MeaningMap.get_tree_indexed_str(tree, mask, "") + print(m) + ne2 = next(it) + ne3 = next(it) # map1 = dl.get_tree(0) # map2 = dl.get_tree(1) diff --git a/wit/train.py b/wit/train.py index 780392e..db8297e 100644 --- a/wit/train.py +++ b/wit/train.py @@ -34,11 +34,16 @@ hidden_size = 1024 # 128 1024 2048 32 num_attention_heads = 16 # 8 8 16 num_hidden_layers = 3 # 6 12 24 3 -mask_level = None -mask_idx = None +mask_level = [0, 1] +mask_idx = [0, 0] + +# mask_level = [0, 1] +# mask_idx = [0, -1] # name = "vocab_ratio_level_data_hidden_head_layer" -name = "rank" +# name = "mask_level_idx" +name = "single_token" + ver = f"{vocab_size}" + "_" + f"{level_ratio}" + "_" + f"{level}" + "_" + f"{dataset_level}" ver = ver + "_" + f"{hidden_size}" + "_" + f"{num_attention_heads}" + "_" + f"{num_hidden_layers}" ver = ver + "_" + f"{mask_level}" + "_" + f"{mask_idx}"