import os import torch import torchvision from torchvision import datasets, transforms import torchvision.models as models import numpy as np from torch.utils.data import Dataset, DataLoader, Subset import random def MNIST(batchsize=8, num_workers=0, shuffle=False): CurrentPath = os.path.split(os.path.realpath(__file__))[0]+"/" train_loader = torch.utils.data.DataLoader( datasets.MNIST(root=CurrentPath+'../Dataset/', train=True, download=True, transform=transforms.Compose([ transforms.ColorJitter(0.2, 0.2), transforms.RandomRotation(30), transforms.RandomResizedCrop(28), transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])), batch_size=batchsize, shuffle=shuffle, num_workers=num_workers, drop_last=True) test_loader = torch.utils.data.DataLoader( datasets.MNIST(root=CurrentPath+'../Dataset/', train=False, transform=transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)) ])), batch_size=batchsize, shuffle=shuffle, num_workers=num_workers, drop_last=True) return train_loader, test_loader def Cifar10(batchsize=8, num_workers=0, shuffle=False): CurrentPath = os.path.split(os.path.realpath(__file__))[0]+"/" train_loader = torch.utils.data.DataLoader( datasets.CIFAR10(root=CurrentPath+'../Dataset/', train=True, download=True, transform=transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)) ])), batch_size=batchsize, shuffle=shuffle, num_workers=num_workers, drop_last=True) test_loader = torch.utils.data.DataLoader( datasets.CIFAR10(root=CurrentPath+'../Dataset/', train=False, transform=transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)) ])), batch_size=batchsize, shuffle=shuffle, num_workers=num_workers, drop_last=True) return train_loader, test_loader def Cifar10Mono(batchsize=8, num_workers=0, shuffle=False, trainsize=0): CurrentPath = os.path.split(os.path.realpath(__file__))[0]+"/" dataset = datasets.CIFAR10(root=CurrentPath+'../Dataset/', train=True, download=True, transform=transforms.Compose([ transforms.Grayscale(), transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)) ])) if trainsize == 0: trainsize = dataset.data.shape[0] train_loader = torch.utils.data.DataLoader( Subset(dataset, range(0, trainsize)), batch_size=batchsize, shuffle=shuffle, num_workers=num_workers, drop_last=True) test_loader = torch.utils.data.DataLoader( datasets.CIFAR10(root=CurrentPath+'../Dataset/', train=False, transform=transforms.Compose([ transforms.Grayscale(), transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)) ])), batch_size=batchsize, shuffle=shuffle, num_workers=num_workers, drop_last=True) return train_loader, test_loader def RandomMnist(batchsize=8, num_workers=0, shuffle=False, style=""): def default_loader(path): if path == "0": return np.random.randint(0, 255, (1, 28, 28)).astype("float32") if style == "Vertical": da = np.random.randint(0, 255, (28)).astype("float32") return np.array([[da for x in range(28)]], dtype="float32") if style == "VerticalOneLine": da = np.zeros((1, 28, 28)).astype("float32") da[0, :, 14] = 100 return da if style == "VerticalZebra": da = np.zeros((1, 28, 28)).astype("float32") da[0, :, ::2] = 100 return da if style == "Horizontal": return np.array([[np.ones((28))*random.randint(0, 255) for x in range(28)]], dtype="float32") if style == "HorizontalOneLine": da = np.zeros((1, 28, 28)).astype("float32") da[0, 14, :] = 100 return da if style == "HorizontalZebra": da = np.zeros((1, 28, 28)).astype("float32") da[0, ::2, :] = 100 return da return np.random.randint(0, 255, (1, 28, 28)).astype("float32") class MyDataset(Dataset): def __init__(self, size, transform=None, target_transform=None, loader=default_loader): imgs = [] for line in range(size): if random.randint(0, 1) == 0: imgs.append(("0", int(0))) else: imgs.append(("1", int(1))) self.imgs = imgs self.transform = transform self.target_transform = target_transform self.loader = loader def __getitem__(self, index): fn, label = self.imgs[index] img = self.loader(fn) img = torch.from_numpy(img) return img, label def __len__(self): return len(self.imgs) train_data = MyDataset(size=50000, transform=transforms.Compose([ transforms.ColorJitter(0.2, 0.2), transforms.RandomRotation(30), transforms.RandomResizedCrop(28), transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])) test_data = MyDataset(size=10000, transform=transforms.Compose([ transforms.ColorJitter(0.2, 0.2), transforms.RandomRotation(30), transforms.RandomResizedCrop(28), transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])) train_loader = torch.utils.data.DataLoader(train_data, batch_size=batchsize, shuffle=shuffle, drop_last=True, num_workers=num_workers, #collate_fn = collate_fn ) test_loader = torch.utils.data.DataLoader(test_data, batch_size=batchsize, shuffle=shuffle, drop_last=True, num_workers=num_workers, #collate_fn = collate_fn ) return train_loader, test_loader