import argparse
from functools import partial
from itertools import chain
from typing import Dict, Tuple

import datasets
import pytorch_lightning as pl
import torch
from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split, Subset


class SpecialDataset(Dataset):
    """Synthetic arithmetic dataset: each row is [a, b, a + b, a + b]."""

    def __init__(self, start=1, end=128, size=1048576):  # 1048576 or 32768
        self.size = size
        self.features = []
        a = torch.randint(start, end, [size])
        b = torch.randint(start, end, [size])
        c = torch.randint(start, end, [size])
        d = torch.randint(start, end, [size])
        z = torch.zeros([size]).long()
        # self.data = torch.stack([a, b, a + b, a + b, a + b * 2]).permute(1, 0)
        # self.data = torch.stack([a, b, a, a + b / 4]).permute(1, 0).long()
        self.data = torch.stack([a, b, a + b, a + b]).permute(1, 0).long()
        # self.data = torch.stack([a, b, a]).permute(1, 0).long()
        # self.data = torch.stack([a, b, a, a + a / 8, a + a / 4, a + a / 2, a + a]).permute(1, 0).long()

        # input  a b c
        # output b c x
        # label  a b c
        # a = torch.randint(start, end, [size])
        # self.data = torch.stack([a, a, a + a]).permute(1, 0)  # accuracy = 0.5
        # self.data = torch.stack([a, a + a, a]).permute(1, 0)  # accuracy = 1
        # Only one algorithm can be learned, and the first value cannot be used for training.
        # Too steep a transition makes it hard to fit.
        # A search space that is too large makes it hard to fit.

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        output = {}
        data = self.data[idx]
        output["input_ids"] = data
        output["labels"] = data.clone()
        # Mask the first two label positions (the operands a and b) with 256 so that
        # only the sum positions contribute to the loss.
        # output["labels"][:2] = 0
        output["labels"][:2] = 256
        output["token_type_ids"] = torch.zeros(data.shape)
        return output
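

# Minimal usage sketch (not part of the original module): wrap SpecialDataset in a
# DataLoader and inspect one batch. The batch size and the smaller dataset size below
# are illustrative assumptions, not values taken from the original training setup.
if __name__ == "__main__":
    dataset = SpecialDataset(start=1, end=128, size=1024)
    loader = DataLoader(dataset, batch_size=8, shuffle=True)
    batch = next(iter(loader))
    # Each input row is [a, b, a + b, a + b]; the first two label positions are
    # masked with 256 so only the sum positions are supervised.
    print(batch["input_ids"].shape)  # torch.Size([8, 4])
    print(batch["labels"][0])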