diff --git a/wit/dataset/meaning_dataset.py b/wit/dataset/meaning_dataset.py
index d0fb767..f94cc10 100644
--- a/wit/dataset/meaning_dataset.py
+++ b/wit/dataset/meaning_dataset.py
@@ -353,7 +353,7 @@ class MeaningDataset(Dataset):
         output["labels"] = data.clone()
         output["token_type_ids"] = torch.zeros(data.shape)
         output["tree"] = [self.tree[i] for i in idx_list]
-        output["mask"] = self.get_seq_mask_tensor(idx_list)
+        output["val_mask"] = self.get_seq_mask_tensor(idx_list)
         return output
 
     def get_token(self, idx):  # must equal sequence length
diff --git a/wit/model/lit_module.py b/wit/model/lit_module.py
index 9c5a4f1..e2303da 100644
--- a/wit/model/lit_module.py
+++ b/wit/model/lit_module.py
@@ -59,8 +59,8 @@ class LitModule(pl.LightningModule):
         logits = logits.contiguous().view(-1, logits.size(-1))
         labels = batch["labels"][..., 1:]
         labels = labels.contiguous().view(-1)
-        if "mask" in batch and batch["mask"] != None:
-            label_mask = batch["mask"][..., 1:]
+        if "val_mask" in batch and batch["val_mask"] != None:
+            label_mask = batch["val_mask"][..., 1:]
             label_mask = label_mask.contiguous().view(-1)
             logits = logits[label_mask]
             labels = labels[label_mask]
diff --git a/wit/train.py b/wit/train.py
index db1e249..cf900ca 100644
--- a/wit/train.py
+++ b/wit/train.py
@@ -25,10 +25,10 @@ if __name__ == "__main__":
     conf.strategy = "auto"
     conf.resume_from_ckpt_path = None
     conf.seed = 42
-    conf.dataloader_works = 4
+    conf.dataloader_works = 2
 
     conf.dataset.meaning.mask_level = [0, 1, 2]
-    conf.dataset.meaning.mask_idx = [0, -1, -1]
+    conf.dataset.meaning.mask_idx = [0, 0, -1]
 
     config.vocab_size = 256
     config.hidden_size = 128  # 128 1024 2048 32
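
Note on the masking change: the dataset now emits its mask under the key "val_mask", and LitModule uses that key to restrict the next-token loss to masked positions. The sketch below is a minimal, self-contained rendering of that logic under stated assumptions, not the repo's code: masked_next_token_loss is a hypothetical helper, and it assumes the logits were already shifted by one position further up in lit_module.py (above the hunk shown), so they align with labels[..., 1:].

import torch
import torch.nn.functional as F

def masked_next_token_loss(logits, labels, val_mask=None):
    # Hypothetical helper mirroring the hunk above: flatten the
    # (already shifted) logits, align labels by dropping the first
    # token, then keep only positions where the boolean mask is True.
    logits = logits.contiguous().view(-1, logits.size(-1))
    labels = labels[..., 1:].contiguous().view(-1)
    if val_mask is not None:
        mask = val_mask[..., 1:].contiguous().view(-1)
        logits = logits[mask]
        labels = labels[mask]
    return F.cross_entropy(logits, labels)

# Example with the vocab_size=256 setting from train.py:
labels = torch.randint(0, 256, (2, 16))   # token ids, batch of 2, seq len 16
val_mask = torch.rand(2, 16) > 0.5        # boolean mask, as produced under "val_mask"
logits = torch.randn(2, 15, 256)          # one shorter: position t predicts token t+1
loss = masked_next_token_loss(logits, labels, val_mask)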