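Refine meaning dataset: rename the batch key `mask` to `val_mask` (in the MeaningDataset output and where LitModule reads it), reduce `dataloader_works` from 4 to 2, and change `mask_idx` from [0, -1, -1] to [0, 0, -1].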
This commit is contained in:
parent
0b19fd576a
commit
bca06af2dc
|
@ -353,7 +353,7 @@ class MeaningDataset(Dataset):
|
||||||
output["labels"] = data.clone()
|
output["labels"] = data.clone()
|
||||||
output["token_type_ids"] = torch.zeros(data.shape)
|
output["token_type_ids"] = torch.zeros(data.shape)
|
||||||
output["tree"] = [self.tree[i] for i in idx_list]
|
output["tree"] = [self.tree[i] for i in idx_list]
|
||||||
output["mask"] = self.get_seq_mask_tensor(idx_list)
|
output["val_mask"] = self.get_seq_mask_tensor(idx_list)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def get_token(self, idx): # must equal sequence length
|
def get_token(self, idx): # must equal sequence length
|
||||||
|
|
|
@ -59,8 +59,8 @@ class LitModule(pl.LightningModule):
|
||||||
logits = logits.contiguous().view(-1, logits.size(-1))
|
logits = logits.contiguous().view(-1, logits.size(-1))
|
||||||
labels = batch["labels"][..., 1:]
|
labels = batch["labels"][..., 1:]
|
||||||
labels = labels.contiguous().view(-1)
|
labels = labels.contiguous().view(-1)
|
||||||
if "mask" in batch and batch["mask"] != None:
|
if "val_mask" in batch and batch["val_mask"] != None:
|
||||||
label_mask = batch["mask"][..., 1:]
|
label_mask = batch["val_mask"][..., 1:]
|
||||||
label_mask = label_mask.contiguous().view(-1)
|
label_mask = label_mask.contiguous().view(-1)
|
||||||
logits = logits[label_mask]
|
logits = logits[label_mask]
|
||||||
labels = labels[label_mask]
|
labels = labels[label_mask]
|
||||||
|
|
|
@ -25,10 +25,10 @@ if __name__ == "__main__":
|
||||||
conf.strategy = "auto"
|
conf.strategy = "auto"
|
||||||
conf.resume_from_ckpt_path = None
|
conf.resume_from_ckpt_path = None
|
||||||
conf.seed = 42
|
conf.seed = 42
|
||||||
conf.dataloader_works = 4
|
conf.dataloader_works = 2
|
||||||
|
|
||||||
conf.dataset.meaning.mask_level = [0, 1, 2]
|
conf.dataset.meaning.mask_level = [0, 1, 2]
|
||||||
conf.dataset.meaning.mask_idx = [0, -1, -1]
|
conf.dataset.meaning.mask_idx = [0, 0, -1]
|
||||||
|
|
||||||
config.vocab_size = 256
|
config.vocab_size = 256
|
||||||
config.hidden_size = 128 # 128 1024 2048 32
|
config.hidden_size = 128 # 128 1024 2048 32
|
||||||
|
|
Loading…
Reference in New Issue