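Refine meaning dataset: rename the batch key "mask" to "val_mask" (dataset output and LitModule consumer), and adjust the training config (dataloader_works 4 -> 2, mask_idx [0, -1, -1] -> [0, 0, -1]).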

This commit is contained in:
Colin 2025-02-20 17:30:46 +08:00
parent 0b19fd576a
commit bca06af2dc
3 changed files with 5 additions and 5 deletions

View File

@ -353,7 +353,7 @@ class MeaningDataset(Dataset):
output["labels"] = data.clone() output["labels"] = data.clone()
output["token_type_ids"] = torch.zeros(data.shape) output["token_type_ids"] = torch.zeros(data.shape)
output["tree"] = [self.tree[i] for i in idx_list] output["tree"] = [self.tree[i] for i in idx_list]
output["mask"] = self.get_seq_mask_tensor(idx_list) output["val_mask"] = self.get_seq_mask_tensor(idx_list)
return output return output
def get_token(self, idx): # must equal sequence length def get_token(self, idx): # must equal sequence length

View File

@ -59,8 +59,8 @@ class LitModule(pl.LightningModule):
logits = logits.contiguous().view(-1, logits.size(-1)) logits = logits.contiguous().view(-1, logits.size(-1))
labels = batch["labels"][..., 1:] labels = batch["labels"][..., 1:]
labels = labels.contiguous().view(-1) labels = labels.contiguous().view(-1)
if "mask" in batch and batch["mask"] != None: if "val_mask" in batch and batch["val_mask"] != None:
label_mask = batch["mask"][..., 1:] label_mask = batch["val_mask"][..., 1:]
label_mask = label_mask.contiguous().view(-1) label_mask = label_mask.contiguous().view(-1)
logits = logits[label_mask] logits = logits[label_mask]
labels = labels[label_mask] labels = labels[label_mask]

View File

@ -25,10 +25,10 @@ if __name__ == "__main__":
conf.strategy = "auto" conf.strategy = "auto"
conf.resume_from_ckpt_path = None conf.resume_from_ckpt_path = None
conf.seed = 42 conf.seed = 42
conf.dataloader_works = 4 conf.dataloader_works = 2
conf.dataset.meaning.mask_level = [0, 1, 2] conf.dataset.meaning.mask_level = [0, 1, 2]
conf.dataset.meaning.mask_idx = [0, -1, -1] conf.dataset.meaning.mask_idx = [0, 0, -1]
config.vocab_size = 256 config.vocab_size = 256
config.hidden_size = 128 # 128 1024 2048 32 config.hidden_size = 128 # 128 1024 2048 32