50 lines
		
	
	
		
			1.8 KiB
		
	
	
	
		
			Python
		
	
	
	
			
		
		
	
	
			50 lines
		
	
	
		
			1.8 KiB
		
	
	
	
		
			Python
		
	
	
	
| import argparse
 | |
| from functools import partial
 | |
| from itertools import chain
 | |
| from typing import Dict, Tuple
 | |
| 
 | |
| import datasets
 | |
| import pytorch_lightning as pl
 | |
| import torch
 | |
| from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split, Subset
 | |
| 
 | |
| 
 | |
class SpecialDataset(Dataset):
    """Synthetic integer-sequence dataset for a toy addition task.

    Each sample is the length-4 sequence ``[a, b, a + b, a + b]``, where ``a``
    and ``b`` are drawn independently with ``torch.randint(start, end, ...)``
    (so each lies in ``[start, end)``). All samples are generated once, up
    front, in ``__init__``.

    ``__getitem__`` returns a dict with ``input_ids`` (the sample),
    ``labels`` (a copy of the sample whose first ``num_prompt_tokens``
    positions are overwritten with ``prompt_label``) and ``token_type_ids``
    (all zeros, same shape and integer dtype as the sample).

    Args:
        start: Inclusive lower bound for the random operands.
        end: Exclusive upper bound for the random operands.
        size: Number of samples (32768 was also used in earlier runs).
        num_prompt_tokens: How many leading label positions to overwrite.
        prompt_label: Value written into those label positions. Earlier runs
            also used ``0``. Whether the loss ignores this value is decided
            by the model/loss outside this class — TODO confirm.

    Experiment notes kept from earlier iterations:
        * ``[a, a, a + a]`` reached accuracy 0.5; ``[a, a + a, a]`` reached 1.0.
        * Only one mapping may be learnable, and the first value cannot be
          used as a training target.
        * Too steep a transition makes the target hard to fit.
        * Too large a search space makes the target hard to fit.
    """

    def __init__(
        self,
        start=1,
        end=128,
        size=1048576,
        *,
        num_prompt_tokens=2,
        prompt_label=256,
    ):
        self.size = size
        self.num_prompt_tokens = num_prompt_tokens
        self.prompt_label = prompt_label
        # Retained for backward compatibility; nothing in this class reads it.
        self.features = []
        a = torch.randint(start, end, [size])
        b = torch.randint(start, end, [size])
        # One row per sample: [a, b, a + b, a + b]. randint already returns
        # int64, and stacking along dim=1 gives a contiguous (size, 4) tensor.
        self.data = torch.stack([a, b, a + b, a + b], dim=1)

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        data = self.data[idx]
        labels = data.clone()
        # Overwrite the leading (operand) positions of the labels only;
        # input_ids is left untouched because labels is a separate copy.
        labels[: self.num_prompt_tokens] = self.prompt_label
        return {
            "input_ids": data,
            "labels": labels,
            # Same shape and integer dtype as input_ids so it can be used as
            # embedding indices (torch.zeros(data.shape) was float32).
            "token_type_ids": torch.zeros_like(data),
        }
 |