158 lines
7.2 KiB
Python
158 lines
7.2 KiB
Python
import os
|
|
import torch
|
|
import torchvision
|
|
from torchvision import datasets, transforms
|
|
import torchvision.models as models
|
|
import numpy as np
|
|
from torch.utils.data import Dataset, DataLoader, Subset
|
|
import random
|
|
|
|
|
|
def MNIST(batchsize=8, num_workers=0, shuffle=False, trainsize=0, resize=0):
    """Build MNIST train/test DataLoaders backed by ``../Dataset/`` beside this file.

    Args:
        batchsize: Samples per batch in both loaders.
        num_workers: Worker processes for both loaders.
        shuffle: Shuffle both loaders. Also enables training-set augmentation
            (colour jitter, random rotation of up to 30 degrees, random
            resized crop to 28) before resizing.
        trainsize: Keep only the first ``trainsize`` training samples; 0 keeps
            them all. Values above the training-set size are clamped to it.
        resize: Side length images are resized to; 0 means 28.

    Returns:
        ``(train_loader, test_loader)``. Both use ``drop_last=True``, so an
        incomplete final batch is discarded.

    Raises:
        ValueError: If ``trainsize`` is negative.
    """
    if trainsize < 0:
        # Checked before any download; a negative count would otherwise give an
        # empty (or, with shuffle=True, invalid) training subset.
        raise ValueError(f"trainsize must be >= 0, got {trainsize}")
    if resize == 0:
        resize = 28

    root = os.path.dirname(os.path.realpath(__file__)) + "/../Dataset/"
    normalize = transforms.Normalize((0.1307,), (0.3081,))
    eval_transform = transforms.Compose(
        [transforms.Resize(resize), transforms.ToTensor(), normalize])
    if shuffle:
        train_transform = transforms.Compose([
            transforms.ColorJitter(0.2, 0.2),
            transforms.RandomRotation(30),
            transforms.RandomResizedCrop(28),
            transforms.Resize(resize),
            transforms.ToTensor(),
            normalize,
        ])
    else:
        train_transform = eval_transform

    train_set = datasets.MNIST(root=root, train=True, download=True,
                               transform=train_transform)
    # Presumably relies on the train-split download above having fetched the
    # test split as well, since download is not requested here.
    test_set = datasets.MNIST(root=root, train=False, transform=eval_transform)

    if trainsize == 0 or trainsize > len(train_set):
        # Clamp: indices past the end would otherwise raise IndexError mid-epoch.
        trainsize = len(train_set)

    train_loader = DataLoader(
        Subset(train_set, range(trainsize)),
        batch_size=batchsize, shuffle=shuffle, num_workers=num_workers, drop_last=True)
    # NOTE(review): the test loader also follows ``shuffle`` and ``drop_last``;
    # when batchsize does not divide the test-set size, the trailing test
    # samples are skipped each epoch.
    test_loader = DataLoader(
        test_set,
        batch_size=batchsize, shuffle=shuffle, num_workers=num_workers, drop_last=True)

    print(f"Train Data size:{trainsize} Shuffle:{shuffle} BatchSize:{batchsize}")

    # Record the sample counts on the samplers, presumably for callers that read
    # ``sampler.num_samples``. When the sampler exposes it as a read-only
    # property (RandomSampler, used when shuffle=True, in recent torch
    # releases), the assignment raises AttributeError and is skipped. Any other
    # error propagates instead of being silently discarded.
    for loader, count in ((train_loader, trainsize), (test_loader, len(test_set))):
        try:
            loader.batch_sampler.sampler.num_samples = count
        except AttributeError:
            pass

    return train_loader, test_loader
|
|
|
|
def Cifar10(batchsize=8, num_workers=0, shuffle=False, trainsize=0):
    """Build CIFAR-10 train/test DataLoaders backed by ``../Dataset/`` beside this file.

    Args:
        batchsize: Samples per batch in both loaders.
        num_workers: Worker processes for both loaders.
        shuffle: Shuffle both loaders (no augmentation is applied).
        trainsize: Keep only the first ``trainsize`` training samples; 0 (the
            default) keeps the full training set, matching ``MNIST`` and
            ``Cifar10Mono``. Values above the training-set size are clamped.

    Returns:
        ``(train_loader, test_loader)``. Both use ``drop_last=True``, so an
        incomplete final batch is discarded.

    Raises:
        ValueError: If ``trainsize`` is negative.
    """
    if trainsize < 0:
        raise ValueError(f"trainsize must be >= 0, got {trainsize}")

    root = os.path.dirname(os.path.realpath(__file__)) + "/../Dataset/"
    # NOTE(review): these are the same mean/std constants the MNIST loader
    # uses (presumably broadcast over all three RGB channels by Normalize), not
    # statistics measured on CIFAR-10 — confirm this is intended.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])

    train_set = datasets.CIFAR10(root=root, train=True, download=True,
                                 transform=transform)
    test_set = datasets.CIFAR10(root=root, train=False, transform=transform)

    train_source = train_set
    if 0 < trainsize < len(train_set):
        # Only wrap when actually truncating, so the default still exposes the
        # CIFAR10 dataset object itself as ``train_loader.dataset``.
        train_source = Subset(train_set, range(trainsize))

    train_loader = DataLoader(
        train_source,
        batch_size=batchsize, shuffle=shuffle, num_workers=num_workers, drop_last=True)
    test_loader = DataLoader(
        test_set,
        batch_size=batchsize, shuffle=shuffle, num_workers=num_workers, drop_last=True)
    return train_loader, test_loader
|
|
|
|
def Cifar10Mono(batchsize=8, num_workers=0, shuffle=False, trainsize=0):
    """Build grayscale CIFAR-10 train/test DataLoaders backed by ``../Dataset/`` beside this file.

    Images are converted with ``transforms.Grayscale()`` before being turned
    into tensors and normalised.

    Args:
        batchsize: Samples per batch in both loaders.
        num_workers: Worker processes for both loaders.
        shuffle: Shuffle both loaders (no augmentation is applied).
        trainsize: Keep only the first ``trainsize`` training samples; 0 keeps
            them all. Values above the training-set size are clamped to it.

    Returns:
        ``(train_loader, test_loader)``. Both use ``drop_last=True``, so an
        incomplete final batch is discarded.

    Raises:
        ValueError: If ``trainsize`` is negative.
    """
    if trainsize < 0:
        # Checked before any download is attempted.
        raise ValueError(f"trainsize must be >= 0, got {trainsize}")

    root = os.path.dirname(os.path.realpath(__file__)) + "/../Dataset/"
    # NOTE(review): same mean/std constants as the MNIST loader, not statistics
    # measured on grayscale CIFAR-10 — confirm this is intended.
    transform = transforms.Compose([
        transforms.Grayscale(),
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])

    train_set = datasets.CIFAR10(root=root, train=True, download=True,
                                 transform=transform)
    test_set = datasets.CIFAR10(root=root, train=False, transform=transform)

    if trainsize == 0 or trainsize > len(train_set):
        # Clamp: indices past the end would otherwise raise IndexError mid-epoch.
        trainsize = len(train_set)

    train_loader = DataLoader(
        Subset(train_set, range(trainsize)),
        batch_size=batchsize, shuffle=shuffle, num_workers=num_workers, drop_last=True)
    test_loader = DataLoader(
        test_set,
        batch_size=batchsize, shuffle=shuffle, num_workers=num_workers, drop_last=True)
    return train_loader, test_loader
|
|
|
|
def RandomMnist(batchsize=8, num_workers=0, shuffle=False, style=""):
    """Build loaders over synthetic 1x28x28 images for a two-class toy task.

    Every sample is labelled 0 or 1 with equal probability when a dataset is
    built. Its image is drawn anew each time the sample is indexed:

    * label 0: uniform integer noise in [0, 255];
    * label 1: the pattern selected by ``style``:

      - ``"Vertical"``: one random row of values in [0, 255] copied onto all
        28 rows, so every column is constant;
      - ``"Horizontal"``: each row filled with its own random value in [0, 255];
      - ``"VerticalOneLine"`` / ``"HorizontalOneLine"``: value 100 along
        column / row 14, zero elsewhere;
      - ``"VerticalZebra"`` / ``"HorizontalZebra"``: value 100 on every even
        column / row, zero elsewhere;
      - any other value (including the default ``""``): the same noise as
        label 0, so the two classes are indistinguishable.

    Images are float32 tensors of raw pixel values; no normalisation or
    augmentation is applied.

    Args:
        batchsize: Samples per batch in both loaders.
        num_workers: Worker processes for both loaders.
        shuffle: Shuffle both loaders.
        style: Pattern used for label-1 images (see above).

    Returns:
        ``(train_loader, test_loader)`` over 50000 and 10000 samples, both with
        ``drop_last=True``.
    """
    shape = (1, 28, 28)

    def noise():
        # numpy's upper bound is exclusive: 256 makes 255 reachable, matching
        # the inclusive random.randint(0, 255) used by the "Horizontal" style.
        return np.random.randint(0, 256, shape).astype("float32")

    def vertical():
        img = np.zeros(shape, dtype="float32")
        img[0] = np.random.randint(0, 256, 28)  # one row, broadcast to every row
        return img

    def horizontal():
        img = np.zeros(shape, dtype="float32")
        levels = [random.randint(0, 255) for _ in range(28)]
        img[0] = np.array(levels, dtype="float32")[:, None]  # one value per row
        return img

    def constant_lines(index):
        # Zero image with value 100 written at the given index expression.
        img = np.zeros(shape, dtype="float32")
        img[index] = 100
        return img

    patterns = {
        "Vertical": vertical,
        "VerticalOneLine": lambda: constant_lines(np.s_[0, :, 14]),
        "VerticalZebra": lambda: constant_lines(np.s_[0, :, ::2]),
        "Horizontal": horizontal,
        "HorizontalOneLine": lambda: constant_lines(np.s_[0, 14, :]),
        "HorizontalZebra": lambda: constant_lines(np.s_[0, ::2, :]),
    }
    # ``style`` is fixed for the lifetime of the datasets, so resolve it once.
    positive = patterns.get(style, noise)

    def default_loader(path):
        # "0" marks a negative (noise) sample; anything else a positive one.
        return noise() if path == "0" else positive()

    class MyDataset(Dataset):
        """Randomly labelled samples whose images are produced by ``loader``."""

        def __init__(self, size, transform=None, target_transform=None, loader=default_loader):
            # Labels are fixed here, one random.randint(0, 1) draw per sample.
            self.imgs = [("0", 0) if random.randint(0, 1) == 0 else ("1", 1)
                         for _ in range(size)]
            self.transform = transform
            self.target_transform = target_transform
            self.loader = loader

        def __getitem__(self, index):
            fn, label = self.imgs[index]
            img = torch.from_numpy(self.loader(fn))
            if self.transform is not None:
                img = self.transform(img)
            if self.target_transform is not None:
                label = self.target_transform(label)
            return img, label

        def __len__(self):
            return len(self.imgs)

    def make_loader(dataset):
        return DataLoader(dataset, batch_size=batchsize, shuffle=shuffle,
                          drop_last=True, num_workers=num_workers)

    # No transforms are passed: the images are tensors built from ndarrays,
    # and are returned as raw values rather than being normalised.
    train_data = MyDataset(size=50000)
    test_data = MyDataset(size=10000)
    return make_loader(train_data), make_loader(test_data)
|
|
|
|
|