diff --git a/wit/special_dataset.py b/wit/special_dataset.py index d69699b..c55a6fd 100644 --- a/wit/special_dataset.py +++ b/wit/special_dataset.py @@ -10,7 +10,7 @@ from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split, S class SpecialDataset(Dataset): - def __init__(self, start=1, end=128, size=32768): # 1048576 32768 + def __init__(self, start=1, end=128, size=1048576): # 1048576 32768 self.size = size self.features = [] a = torch.randint(start, end, [size]) @@ -20,7 +20,7 @@ class SpecialDataset(Dataset): z = torch.zeros([size]).long() # self.data = torch.stack([a, b, a + b, a + b, a + b * 2]).permute(1, 0) # self.data = torch.stack([a, b, a, a + b / 4]).permute(1, 0).long() - self.data = torch.stack([a, a, a + a]).permute(1, 0).long() + self.data = torch.stack([a, b, a + b, a + b]).permute(1, 0).long() # self.data = torch.stack([a, b, a]).permute(1, 0).long() # self.data = torch.stack([a, b, a, a + a / 8, a + a / 4, a + a / 2, a + a]).permute(1, 0).long() @@ -44,6 +44,6 @@ class SpecialDataset(Dataset): output["input_ids"] = data output["labels"] = data.clone() # output["labels"][:2] = 0 - # output["labels"][:2] = vocab_size + output["labels"][:2] = 256 output["token_type_ids"] = torch.zeros(data.shape) return output