witnn/tools/Loader.py

144 lines
6.4 KiB
Python

# from __future__ import print_function
import os
import torch
import torchvision
from torchvision import datasets, transforms
import torchvision.models as models
import numpy as np
from torch.utils.data import Dataset, DataLoader
import random
def MNIST(batchsize=8, num_workers=0, shuffle=False):
    """Return ``(train_loader, test_loader)`` for the MNIST digits.

    The data lives in ``../Dataset/`` relative to this file. The training
    split is created with ``download=True``; the test split reuses the same
    root.

    Training images get brightness/contrast jitter, a random rotation of up
    to +/-30 degrees and a random resized crop back to 28x28 before being
    converted to tensors and normalized. Test images are only converted and
    normalized.

    Args:
        batchsize: Samples per batch for both loaders.
        num_workers: DataLoader worker processes for both loaders.
        shuffle: Shuffle flag, applied to the train AND the test loader.

    Returns:
        ``(train_loader, test_loader)``. Both drop the final incomplete batch.
    """
    here = os.path.split(os.path.realpath(__file__))[0] + "/"
    root = here + '../Dataset/'

    # Normalize is stateless, so a single instance can serve both pipelines.
    normalize = transforms.Normalize((0.1307,), (0.3081,))
    augmented = transforms.Compose([
        transforms.ColorJitter(0.2, 0.2),
        transforms.RandomRotation(30),
        transforms.RandomResizedCrop(28),
        transforms.ToTensor(),
        normalize,
    ])
    plain = transforms.Compose([transforms.ToTensor(), normalize])

    train_set = datasets.MNIST(root=root, train=True, download=True,
                               transform=augmented)
    test_set = datasets.MNIST(root=root, train=False, transform=plain)

    loader_opts = dict(batch_size=batchsize, shuffle=shuffle,
                       num_workers=num_workers, drop_last=True)
    return (torch.utils.data.DataLoader(train_set, **loader_opts),
            torch.utils.data.DataLoader(test_set, **loader_opts))
def Cifar10(batchsize=8, num_workers=0, shuffle=False,
            mean=(0.1307,), std=(0.3081,)):
    """Return ``(train_loader, test_loader)`` for RGB CIFAR-10.

    The data lives in ``../Dataset`` relative to this file. The training
    split is created with ``download=True``; the test split reuses the same
    root. Both splits are converted to tensors and normalized. No
    augmentation is applied.

    Args:
        batchsize: Samples per batch for both loaders.
        num_workers: DataLoader worker processes for both loaders.
        shuffle: Shuffle flag, applied to the train AND the test loader.
        mean: Per-channel means passed to ``transforms.Normalize``.
        std: Per-channel standard deviations passed to
            ``transforms.Normalize``.

    Note:
        The ``mean``/``std`` defaults are the MNIST statistics this loader
        has always used. They are a single value that ``Normalize``
        broadcasts over all three channels, and they are kept only for
        backward compatibility with existing experiments. For properly
        normalized CIFAR-10 inputs, pass dataset statistics instead. The
        commonly cited training-set values are
        ``mean=(0.4914, 0.4822, 0.4465), std=(0.2470, 0.2435, 0.2616)``.

    Returns:
        ``(train_loader, test_loader)``. Both drop the final incomplete batch,
        so up to ``batchsize - 1`` test images are skipped per evaluation.
    """
    root = os.path.join(os.path.dirname(os.path.realpath(__file__)),
                        "..", "Dataset")
    # ToTensor + Normalize hold no per-sample state, so both splits can share
    # one pipeline.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    train_set = datasets.CIFAR10(root=root, train=True, download=True,
                                 transform=transform)
    test_set = datasets.CIFAR10(root=root, train=False, transform=transform)

    loader_kwargs = dict(batch_size=batchsize, shuffle=shuffle,
                         num_workers=num_workers, drop_last=True)
    train_loader = torch.utils.data.DataLoader(train_set, **loader_kwargs)
    test_loader = torch.utils.data.DataLoader(test_set, **loader_kwargs)
    return train_loader, test_loader
def Cifar10Mono(batchsize=8, num_workers=0, shuffle=False,
                mean=(0.1307,), std=(0.3081,)):
    """Return ``(train_loader, test_loader)`` for single-channel CIFAR-10.

    Images are converted to grayscale (``transforms.Grayscale()``, one
    output channel), then converted to tensors and normalized. The data
    lives in ``../Dataset`` relative to this file. The training split is
    created with ``download=True``; the test split reuses the same root.

    Args:
        batchsize: Samples per batch for both loaders.
        num_workers: DataLoader worker processes for both loaders.
        shuffle: Shuffle flag, applied to the train AND the test loader.
        mean: Channel mean passed to ``transforms.Normalize``.
        std: Channel standard deviation passed to ``transforms.Normalize``.

    Note:
        The ``mean``/``std`` defaults are MNIST statistics, not grayscale
        CIFAR-10 statistics. They are kept only for backward compatibility
        with existing experiments. Pass values computed on grayscale CIFAR-10
        for properly normalized inputs.

    Returns:
        ``(train_loader, test_loader)``. Both drop the final incomplete batch.
    """
    root = os.path.join(os.path.dirname(os.path.realpath(__file__)),
                        "..", "Dataset")
    # Grayscale/ToTensor/Normalize are deterministic, so both splits can
    # share one pipeline.
    transform = transforms.Compose([
        transforms.Grayscale(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    train_set = datasets.CIFAR10(root=root, train=True, download=True,
                                 transform=transform)
    test_set = datasets.CIFAR10(root=root, train=False, transform=transform)

    loader_kwargs = dict(batch_size=batchsize, shuffle=shuffle,
                         num_workers=num_workers, drop_last=True)
    train_loader = torch.utils.data.DataLoader(train_set, **loader_kwargs)
    test_loader = torch.utils.data.DataLoader(test_set, **loader_kwargs)
    return train_loader, test_loader
def _random_mnist_image(path, style=""):
    """Generate one synthetic float32 image of shape (1, 28, 28) for ``RandomMnist``.

    Args:
        path: Pseudo file name. ``"0"`` always yields uniform noise. Any
            other value yields the pattern selected by ``style``.
        style: One of ``"Vertical"``, ``"VerticalOneLine"``,
            ``"VerticalZebra"``, ``"Horizontal"``, ``"HorizontalOneLine"`` or
            ``"HorizontalZebra"``. Any other value (including ``""``) falls
            back to uniform noise.

    Returns:
        A numpy float32 array of shape (1, 28, 28) with values in [0, 255].
    """
    # np.random.randint excludes ``high``, so 256 is needed to cover 0..255.
    # This matches the inclusive random.randint(0, 255) of the Horizontal style.
    if path == "0":
        return np.random.randint(0, 256, (1, 28, 28)).astype("float32")
    if style == "Vertical":
        # One random row repeated 28 times: every column is constant.
        row = np.random.randint(0, 256, 28).astype("float32")
        return np.array([[row for _ in range(28)]], dtype="float32")
    if style == "VerticalOneLine":
        img = np.zeros((1, 28, 28)).astype("float32")
        img[0, :, 14] = 100
        return img
    if style == "VerticalZebra":
        img = np.zeros((1, 28, 28)).astype("float32")
        img[0, :, ::2] = 100
        return img
    if style == "Horizontal":
        # Each row gets its own random constant: every row is constant.
        return np.array(
            [[np.ones(28) * random.randint(0, 255) for _ in range(28)]],
            dtype="float32")
    if style == "HorizontalOneLine":
        img = np.zeros((1, 28, 28)).astype("float32")
        img[0, 14, :] = 100
        return img
    if style == "HorizontalZebra":
        img = np.zeros((1, 28, 28)).astype("float32")
        img[0, ::2, :] = 100
        return img
    return np.random.randint(0, 256, (1, 28, 28)).astype("float32")


class _RandomMnistDataset(Dataset):
    """Binary-labelled synthetic dataset backing ``RandomMnist``.

    Defined at module level, not inside ``RandomMnist``, so that it can be
    pickled when a DataLoader starts worker processes with the ``spawn``
    start method. A class local to a function cannot be pickled.
    """

    def __init__(self, size, transform=None, target_transform=None,
                 loader=None, style=""):
        # Labels are drawn once here. Images are regenerated on every access.
        labels = [random.randint(0, 1) for _ in range(size)]
        # (pseudo file name, label); the name is what the loader switches on.
        self.imgs = [(str(label), label) for label in labels]
        self.transform = transform
        self.target_transform = target_transform
        self.style = style
        self.loader = loader if loader is not None else self._default_loader

    def _default_loader(self, path):
        # Bound method rather than a closure so the dataset stays picklable.
        return _random_mnist_image(path, self.style)

    def __getitem__(self, index):
        fn, label = self.imgs[index]
        img = torch.from_numpy(self.loader(fn))
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return img, label

    def __len__(self):
        return len(self.imgs)


def RandomMnist(batchsize=8, num_workers=0, shuffle=False, style="",
                train_size=50000, test_size=10000):
    """Return ``(train_loader, test_loader)`` over synthetic 28x28 images.

    Each sample gets a label of 0 or 1, drawn once at construction.
    Label-0 samples are always uniform noise in [0, 255]. Label-1 samples
    use the pattern named by ``style``, or uniform noise when ``style`` is
    empty or unknown; in that case the two classes are indistinguishable.
    Images are regenerated on every access, so they change between epochs
    while labels stay fixed.

    Samples are raw float32 tensors of shape (1, 28, 28) in the 0-255 range.
    No augmentation or normalization is applied. The MNIST pipeline's
    ``ToTensor`` step expects a PIL image or an HWC array, not these CHW
    arrays.

    Args:
        batchsize: Samples per batch for both loaders.
        num_workers: DataLoader worker processes for both loaders.
        shuffle: Shuffle flag, applied to the train AND the test loader.
        style: Pattern used for label-1 samples (see ``_random_mnist_image``).
        train_size: Number of training samples.
        test_size: Number of test samples.

    Returns:
        ``(train_loader, test_loader)``. Both drop the final incomplete batch.
    """
    train_data = _RandomMnistDataset(train_size, style=style)
    test_data = _RandomMnistDataset(test_size, style=style)
    loader_kwargs = dict(batch_size=batchsize, shuffle=shuffle,
                         drop_last=True, num_workers=num_workers)
    train_loader = DataLoader(train_data, **loader_kwargs)
    test_loader = DataLoader(test_data, **loader_kwargs)
    return train_loader, test_loader