Update special dataset.
This commit is contained in:
parent
01e5f86e94
commit
c4f7ef2813
|
@ -10,7 +10,7 @@ from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split, S
|
||||||
|
|
||||||
|
|
||||||
class SpecialDataset(Dataset):
|
class SpecialDataset(Dataset):
|
||||||
def __init__(self, start=1, end=128, size=32768): # 1048576 32768
|
def __init__(self, start=1, end=128, size=1048576): # 1048576 32768
|
||||||
self.size = size
|
self.size = size
|
||||||
self.features = []
|
self.features = []
|
||||||
a = torch.randint(start, end, [size])
|
a = torch.randint(start, end, [size])
|
||||||
|
@ -20,7 +20,7 @@ class SpecialDataset(Dataset):
|
||||||
z = torch.zeros([size]).long()
|
z = torch.zeros([size]).long()
|
||||||
# self.data = torch.stack([a, b, a + b, a + b, a + b * 2]).permute(1, 0)
|
# self.data = torch.stack([a, b, a + b, a + b, a + b * 2]).permute(1, 0)
|
||||||
# self.data = torch.stack([a, b, a, a + b / 4]).permute(1, 0).long()
|
# self.data = torch.stack([a, b, a, a + b / 4]).permute(1, 0).long()
|
||||||
self.data = torch.stack([a, a, a + a]).permute(1, 0).long()
|
self.data = torch.stack([a, b, a + b, a + b]).permute(1, 0).long()
|
||||||
# self.data = torch.stack([a, b, a]).permute(1, 0).long()
|
# self.data = torch.stack([a, b, a]).permute(1, 0).long()
|
||||||
# self.data = torch.stack([a, b, a, a + a / 8, a + a / 4, a + a / 2, a + a]).permute(1, 0).long()
|
# self.data = torch.stack([a, b, a, a + a / 8, a + a / 4, a + a / 2, a + a]).permute(1, 0).long()
|
||||||
|
|
||||||
|
@ -44,6 +44,6 @@ class SpecialDataset(Dataset):
|
||||||
output["input_ids"] = data
|
output["input_ids"] = data
|
||||||
output["labels"] = data.clone()
|
output["labels"] = data.clone()
|
||||||
# output["labels"][:2] = 0
|
# output["labels"][:2] = 0
|
||||||
# output["labels"][:2] = vocab_size
|
output["labels"][:2] = 256
|
||||||
output["token_type_ids"] = torch.zeros(data.shape)
|
output["token_type_ids"] = torch.zeros(data.shape)
|
||||||
return output
|
return output
|
||||||
|
|
Loading…
Reference in New Issue