witnn/tools/Loader.py

142 lines
6.5 KiB
Python

import os
import torch
import torchvision
from torchvision import datasets, transforms
import torchvision.models as models
import numpy as np
from torch.utils.data import Dataset, DataLoader, Subset
import random
def MNIST(batchsize=8, num_workers=0, shuffle=False):
    """Create DataLoaders over the torchvision MNIST train and test splits.

    The dataset lives under ``../Dataset/`` relative to this module's
    directory. Only the training split is requested with ``download=True``.
    Training images pass through ``ColorJitter(0.2, 0.2)``,
    ``RandomRotation(30)`` and ``RandomResizedCrop(28)`` before tensor
    conversion and normalization; test images are only converted and
    normalized.

    Args:
        batchsize: Batch size used by both loaders.
        num_workers: Worker count forwarded to both loaders.
        shuffle: Shuffle flag; note it is applied to the test loader too.

    Returns:
        A ``(train_loader, test_loader)`` tuple.
    """
    module_dir = os.path.dirname(os.path.realpath(__file__)) + "/"
    data_root = module_dir + '../Dataset/'

    # Both loaders share the same settings; drop_last discards the final
    # incomplete batch on the training AND the test side.
    loader_kwargs = dict(
        batch_size=batchsize,
        shuffle=shuffle,
        num_workers=num_workers,
        drop_last=True,
    )

    augmented = transforms.Compose([
        transforms.ColorJitter(0.2, 0.2),
        transforms.RandomRotation(30),
        transforms.RandomResizedCrop(28),
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])
    train_set = datasets.MNIST(
        root=data_root, train=True, download=True, transform=augmented)
    train_loader = DataLoader(train_set, **loader_kwargs)

    plain = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])
    # No download flag here: the files are expected to exist already after
    # the training split was requested above.
    test_set = datasets.MNIST(root=data_root, train=False, transform=plain)
    test_loader = DataLoader(test_set, **loader_kwargs)

    return train_loader, test_loader
def Cifar10(batchsize=8, num_workers=0, shuffle=False):
    """Create DataLoaders over the torchvision CIFAR10 train and test splits.

    The dataset lives under ``../Dataset/`` relative to this module's
    directory. Only the training split is requested with ``download=True``.
    Both splits use the same preprocessing: tensor conversion followed by
    normalization.

    Args:
        batchsize: Batch size used by both loaders.
        num_workers: Worker count forwarded to both loaders.
        shuffle: Shuffle flag; note it is applied to the test loader too.

    Returns:
        A ``(train_loader, test_loader)`` tuple.
    """
    module_dir = os.path.dirname(os.path.realpath(__file__)) + "/"
    data_root = module_dir + '../Dataset/'

    def build_pipeline():
        # NOTE(review): these are the same single-value constants that
        # MNIST() uses in this file, applied here to colour images --
        # confirm this is intended rather than per-channel statistics.
        return transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])

    # drop_last discards the final incomplete batch on both splits.
    loader_kwargs = dict(
        batch_size=batchsize,
        shuffle=shuffle,
        num_workers=num_workers,
        drop_last=True,
    )

    train_set = datasets.CIFAR10(
        root=data_root, train=True, download=True, transform=build_pipeline())
    train_loader = DataLoader(train_set, **loader_kwargs)

    test_set = datasets.CIFAR10(
        root=data_root, train=False, transform=build_pipeline())
    test_loader = DataLoader(test_set, **loader_kwargs)

    return train_loader, test_loader
def Cifar10Mono(batchsize=8, num_workers=0, shuffle=False, trainsize=0):
    """Create DataLoaders over a grayscale version of CIFAR10.

    Both splits are converted with ``transforms.Grayscale()`` before tensor
    conversion and normalization. The training loader only sees the first
    ``trainsize`` samples of the training split.

    Args:
        batchsize: Batch size used by both loaders.
        num_workers: Worker count forwarded to both loaders.
        shuffle: Shuffle flag; note it is applied to the test loader too.
        trainsize: Number of leading training samples to keep. ``0`` means
            "use the whole training split".

    Returns:
        A ``(train_loader, test_loader)`` tuple.
    """
    module_dir = os.path.dirname(os.path.realpath(__file__)) + "/"
    data_root = module_dir + '../Dataset/'

    def build_pipeline():
        return transforms.Compose([
            transforms.Grayscale(),
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])

    # drop_last discards the final incomplete batch on both splits.
    loader_kwargs = dict(
        batch_size=batchsize,
        shuffle=shuffle,
        num_workers=num_workers,
        drop_last=True,
    )

    full_train_set = datasets.CIFAR10(
        root=data_root, train=True, download=True, transform=build_pipeline())

    # trainsize == 0 is the sentinel for "no truncation".
    kept = trainsize if trainsize != 0 else full_train_set.data.shape[0]
    train_loader = DataLoader(
        Subset(full_train_set, range(kept)), **loader_kwargs)

    test_set = datasets.CIFAR10(
        root=data_root, train=False, transform=build_pipeline())
    test_loader = DataLoader(test_set, **loader_kwargs)

    return train_loader, test_loader
def _random_mnist_image(path, style=""):
    """Generate one synthetic float32 image of shape (1, 28, 28).

    Args:
        path: Class tag of the sample. ``"0"`` always yields uniform random
            noise; any other tag yields the pattern selected by ``style``.
        style: Pattern name for non-"0" samples. Recognised values are
            ``"Vertical"``, ``"VerticalOneLine"``, ``"VerticalZebra"``,
            ``"Horizontal"``, ``"HorizontalOneLine"`` and
            ``"HorizontalZebra"``. Any other value falls back to random
            noise, i.e. the same distribution as class "0".

    Returns:
        A ``numpy.ndarray`` with dtype float32 and shape (1, 28, 28).
    """
    if path == "0":
        return np.random.randint(0, 255, (1, 28, 28)).astype("float32")

    if style == "Vertical":
        # One random value per column, repeated down all 28 rows.
        columns = np.random.randint(0, 255, (28)).astype("float32")
        return np.tile(columns, (1, 28, 1))
    if style == "Horizontal":
        # One random value per row (drawn from Python's `random`), repeated
        # across all 28 columns.
        return np.array(
            [[np.full(28, random.randint(0, 255)) for _ in range(28)]],
            dtype="float32")

    canvas = np.zeros((1, 28, 28), dtype="float32")
    if style == "VerticalOneLine":
        canvas[0, :, 14] = 100
        return canvas
    if style == "VerticalZebra":
        canvas[0, :, ::2] = 100
        return canvas
    if style == "HorizontalOneLine":
        canvas[0, 14, :] = 100
        return canvas
    if style == "HorizontalZebra":
        canvas[0, ::2, :] = 100
        return canvas

    # Unknown / empty style: plain noise.
    return np.random.randint(0, 255, (1, 28, 28)).astype("float32")


class _RandomMnistDataset(Dataset):
    """Two-class synthetic dataset whose images are generated on access.

    Labels are drawn once at construction time with ``random.randint(0, 1)``.
    Images are NOT stored: every ``__getitem__`` call synthesises a fresh
    image, so the same index can return different pixels on each access.

    The class lives at module level (rather than inside ``RandomMnist``) so
    instances can be pickled, which DataLoader worker processes require when
    they are started by spawning.
    """

    def __init__(self, size, style=""):
        # Each entry is (class tag, integer label); the tag is what the
        # image generator dispatches on.
        self.imgs = [
            ("0", 0) if random.randint(0, 1) == 0 else ("1", 1)
            for _ in range(size)
        ]
        self.style = style

    def __getitem__(self, index):
        tag, label = self.imgs[index]
        return torch.from_numpy(_random_mnist_image(tag, self.style)), label

    def __len__(self):
        return len(self.imgs)


def RandomMnist(batchsize=8, num_workers=0, shuffle=False, style=""):
    """Create DataLoaders over a synthetic, MNIST-sized two-class dataset.

    Class 0 samples are random-noise images; class 1 samples follow the
    pattern chosen by ``style`` (see ``_random_mnist_image``). The training
    set has 50000 samples and the test set 10000. Images are raw float32
    arrays: no augmentation or normalization is applied to them.

    Args:
        batchsize: Batch size used by both loaders.
        num_workers: Worker count forwarded to both loaders.
        shuffle: Shuffle flag; note it is applied to the test loader too.
        style: Pattern name for class 1 images; the default ``""`` makes
            class 1 random noise as well.

    Returns:
        A ``(train_loader, test_loader)`` tuple.
    """
    # drop_last discards the final incomplete batch on both splits.
    loader_kwargs = dict(
        batch_size=batchsize,
        shuffle=shuffle,
        drop_last=True,
        num_workers=num_workers,
    )
    train_data = _RandomMnistDataset(size=50000, style=style)
    test_data = _RandomMnistDataset(size=10000, style=style)
    train_loader = DataLoader(train_data, **loader_kwargs)
    test_loader = DataLoader(test_data, **loader_kwargs)
    return train_loader, test_loader