50 lines
1.8 KiB
Python
50 lines
1.8 KiB
Python
import argparse
|
|
from functools import partial
|
|
from itertools import chain
|
|
from typing import Dict, Tuple
|
|
|
|
import datasets
|
|
import pytorch_lightning as pl
|
|
import torch
|
|
from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split, Subset
|
|
|
|
|
|
class SpecialDataset(Dataset):
    """Synthetic integer-sequence dataset for a toy "learn addition" task.

    Every sample is the length-4 int64 sequence ``[a, b, a + b, a + b]``,
    where ``a`` and ``b`` are drawn independently and uniformly from
    ``[start, end)`` with ``torch.randint`` (so results depend on the global
    torch RNG state).

    Args:
        start: Inclusive lower bound for the random operands.
        end: Exclusive upper bound for the random operands.
        size: Number of samples (1048576 = 2**20; the original source also
            notes 32768 as an alternative size).

    Experiment notes kept from the original source (other ``self.data``
    layouts that were tried)::

        [a, b, a + b, a + b, a + b * 2]
        [a, b, a, a + b / 4]
        [a, b, a]
        [a, b, a, a + a / 8, a + a / 4, a + a / 2, a + a]
        [a, a, a + a]    # accuracy = 0.5
        [a, a + a, a]    # accuracy = 1

        input:  a b c
        output: b c x
        label:  a b c

        - Only one rule may be present, and the first value cannot be used
          for training.
        - Too steep a transition makes the data hard to fit.
        - Too large a search space makes the data hard to fit.
    """

    # Value written over label positions that should not be supervised.
    # NOTE(review): presumably this must match the training loss's
    # ignore_index (an earlier variant used 0 instead) -- confirm with the model.
    MASKED_LABEL = 256
    # Number of leading positions (the random operands a and b) whose labels
    # are masked: they are independent random draws, so there is nothing
    # learnable to predict there.
    NUM_MASKED = 2

    def __init__(self, start=1, end=128, size=1048576):
        self.size = size
        # Not used by this class; kept so existing attribute access still works.
        self.features = []

        a = torch.randint(start, end, [size])
        b = torch.randint(start, end, [size])
        # Stacking along dim=1 directly yields a contiguous (size, 4) tensor.
        # .long() is a no-op for the integer formula below, but keeps the dtype
        # int64 if a variant with division (see class docstring) is swapped in.
        self.data = torch.stack([a, b, a + b, a + b], dim=1).long()

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        """Return the sample(s) at ``idx`` as a model-input dict.

        ``idx`` may be an int (a single length-4 sequence) or any index that
        tensor indexing accepts, such as a list of ints (a ``(k, 4)`` batch).

        Returns:
            dict with:
              - ``input_ids``: the selected row(s) of ``self.data``;
              - ``labels``: a copy of ``input_ids`` with the first
                ``NUM_MASKED`` positions set to ``MASKED_LABEL``;
              - ``token_type_ids``: all zeros, same shape and int64 dtype as
                ``input_ids``.
        """
        data = self.data[idx]
        labels = data.clone()
        # Ellipsis targets the sequence axis for both a single row and a batch
        # of rows (a plain [:2] would blank entire rows of a batch instead).
        labels[..., : self.NUM_MASKED] = self.MASKED_LABEL
        return {
            "input_ids": data,
            "labels": labels,
            # Integer dtype: token type ids are typically used as indices into
            # an embedding table, which rejects float tensors.
            "token_type_ids": torch.zeros_like(data),
        }
|