from __future__ import print_function

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
import torchvision.models as models
import matplotlib.pyplot as plt
import numpy as np
from visdom import Visdom

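# The Visdom calls below assume a server is already listening on the default
# port, started in another shell with `python -m visdom.server`; without it,
# the viz.* calls cannot connect.
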
# Connect to the Visdom server and clear out the default environment.
viz = Visdom()
viz.delete_env('main')

# Pre-create six image windows filled with noise; during training they are
# overwritten with feature maps (1-3) and filter weights (4). Windows 5 and 6
# are reserved but currently unused.
imagewin1 = viz.image(np.random.rand(3, 512, 256), opts=dict(title='Random!', caption='How random.'))
imagewin2 = viz.image(np.random.rand(3, 512, 256), opts=dict(title='Random!', caption='How random.'))
imagewin3 = viz.image(np.random.rand(3, 512, 256), opts=dict(title='Random!', caption='How random.'))
imagewin4 = viz.image(np.random.rand(3, 512, 256), opts=dict(title='Random!', caption='How random.'))
imagewin5 = viz.image(np.random.rand(3, 512, 256), opts=dict(title='Random!', caption='How random.'))
imagewin6 = viz.image(np.random.rand(3, 512, 256), opts=dict(title='Random!', caption='How random.'))

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Training dataset; (0.1307,) and (0.3081,) are the standard MNIST mean/std.
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(root='.', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])), batch_size=64, shuffle=True, num_workers=4)

# Test dataset
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(root='.', train=False, transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])), batch_size=64, shuffle=True, num_workers=4)

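# Quick shape check (an illustrative sketch, safe to delete): each batch from
# the loader should be a [64, 1, 28, 28] float tensor of normalized pixels.
# _imgs, _labels = next(iter(train_loader))
# print(_imgs.shape, _labels.shape)  # torch.Size([64, 1, 28, 28]) torch.Size([64])
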
# modelnet = models.vgg11(pretrained=True)

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

        # Spatial transformer localization-network
        self.localization = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=7),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True),
            nn.Conv2d(8, 10, kernel_size=5),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True)
        )

        # Regressor for the 3 * 2 affine matrix
        self.fc_loc = nn.Sequential(
            nn.Linear(10 * 3 * 3, 32),
            nn.ReLU(True),
            nn.Linear(32, 3 * 2)
        )

        # Initialize the weights/bias with identity transformation
        self.fc_loc[2].weight.data.zero_()
        self.fc_loc[2].bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))

    # Spatial transformer network forward function
    def stn(self, x):
        xs = self.localization(x)
        xs = xs.view(-1, 10 * 3 * 3)
        theta = self.fc_loc(xs)
        theta = theta.view(-1, 2, 3)

        grid = F.affine_grid(theta, x.size())
        x = F.grid_sample(x, grid)

        return x

    def forward(self, x):
        # transform the input
        x = self.stn(x)

        # Perform the usual forward pass
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

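# Dimension check for the STN branch (a sketch, assuming 28x28 MNIST inputs):
# the localization net shrinks 28 -> 22 -> 11 -> 7 -> 3, which is why fc_loc
# expects 10 * 3 * 3 inputs, and stn() preserves the input shape.
# _net = Net()
# assert _net.stn(torch.zeros(2, 1, 28, 28)).shape == (2, 1, 28, 28)
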
class NetMnist(nn.Module):
    def __init__(self):
        super(NetMnist, self).__init__()
        self.conv1 = nn.Conv2d(1, 4, kernel_size=5)
        self.conv2 = nn.Conv2d(4, 8, kernel_size=3)
        self.conv3 = nn.Conv2d(8, 16, kernel_size=5)
        self.fc1 = nn.Linear(1 * 16, 10)

    def forward(self, x):
        # 28x28 -> conv1 -> 24x24 -> pool -> 12x12
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        # 12x12 -> conv2 -> 10x10 -> pool -> 5x5
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        # 5x5 -> conv3 -> 1x1; plain ReLU, no pooling left to do
        x = F.relu(self.conv3(x))

        # Each sample is now a 16-channel 1x1 map, i.e. a 16-vector.
        x = x.view(-1, 1 * 16)
        x = F.relu(self.fc1(x))

        return F.log_softmax(x, dim=1)

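# Sanity sketch for the size arithmetic above (illustrative, safe to delete):
# assert NetMnist()(torch.zeros(2, 1, 28, 28)).shape == (2, 10)
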
# model = torch.nn.DataParallel(NetMnist()).to(device)
model = NetMnist().to(device)

optimizer = optim.SGD(model.parameters(), lr=0.01)

# Holds the first training batch so the same inputs can be re-visualized
# after every epoch.
ad = None

def train(epoch):
    global ad
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()

        # Remember the first batch for the per-epoch visualizations below.
        if ad is None:
            ad = data

        # nda = data.numpy()
        # nao = np.append(nda, nda, 1)
        # nao = np.append(nao, nda, 1)
        # data = torch.tensor(nao)

        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()

        # 930 is just under the number of 64-sample batches in MNIST (938),
        # so this logs once near the end of each epoch.
        if batch_idx % 930 == 0 and batch_idx > 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

def test():
    with torch.no_grad():
        model.eval()
        test_loss = 0
        correct = 0
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)

            # sum up batch loss ('size_average=False' is deprecated;
            # reduction='sum' is the current equivalent)
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            # get the index of the max log-probability
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()

        test_loss /= len(test_loader.dataset)
        print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
            test_loss, correct, len(test_loader.dataset),
            100. * correct / len(test_loader.dataset)))

# One image list per Visdom window; each epoch appends a new row of maps.
imaglis1 = []
imaglis2 = []
imaglis3 = []
imaglis4 = []
imaglis5 = []
imaglis6 = []

for epoch in range(1, 100 + 1):
    train(epoch)

    # Feature maps of conv1 for the first remembered sample, shifted and
    # scaled into roughly the 0-255 pixel range for display.
    mm1 = model.conv1(ad[0:1, :, :, :])
    datodis = mm1 * 256 + 128
    dashape = datodis.shape
    datodis = datodis.view(dashape[1], 1, dashape[2], dashape[3])
    imaglis1.append(datodis.detach().cpu().numpy())

    imaglisnp = np.array(imaglis1)
    cdsa = np.reshape(imaglisnp, newshape=[-1, 1, imaglisnp.shape[3], imaglisnp.shape[4]])
    viz.images(cdsa, win=imagewin1, opts=dict(title='Random!', caption='How random.'), nrow=dashape[1])

    # Feature maps of conv2, fed with the pooled conv1 activations.
    mm2 = model.conv2(F.relu(F.max_pool2d(mm1, 2)))
    datodis = mm2 * 256 + 128
    dashape = datodis.shape
    datodis = datodis.view(dashape[1], 1, dashape[2], dashape[3])
    imaglis2.append(datodis.detach().cpu().numpy())

    imaglisnp = np.array(imaglis2)
    cdsa = np.reshape(imaglisnp, newshape=[-1, 1, imaglisnp.shape[3], imaglisnp.shape[4]])
    viz.images(cdsa, win=imagewin2, opts=dict(title='Random!', caption='How random.'), nrow=dashape[1])

    # Feature maps of conv3 (a single 1x1 value per channel at this depth).
    mm3 = model.conv3(F.relu(F.max_pool2d(mm2, 2)))
    datodis = mm3 * 256 + 128
    dashape = datodis.shape
    datodis = datodis.view(dashape[1], 1, dashape[2], dashape[3])
    imaglis3.append(datodis.detach().cpu().numpy())

    imaglisnp = np.array(imaglis3)
    cdsa = np.reshape(imaglisnp, newshape=[-1, 1, imaglisnp.shape[3], imaglisnp.shape[4]])
    viz.images(cdsa, win=imagewin3, opts=dict(title='Random!', caption='How random.'), nrow=dashape[1])

    # conv1 filter weights, one grayscale tile per 5x5 kernel.
    mm4 = model.conv1.weight.data
    datodis = mm4 * 256 + 128
    dashape = datodis.shape
    datodis = datodis.view(dashape[1] * dashape[0], 1, dashape[2], dashape[3])
    imaglis4.append(datodis.detach().cpu().numpy())

    imaglisnp = np.array(imaglis4)
    cdsa = np.reshape(imaglisnp, newshape=[-1, 1, imaglisnp.shape[3], imaglisnp.shape[4]])
    viz.images(cdsa, win=imagewin4, opts=dict(title='Random!', caption='How random.'), nrow=dashape[1] * dashape[0])

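    # The four blocks above repeat one pattern; a helper along these lines
    # (hypothetical, shown only as a sketch) could replace them:
    #
    # def show_maps(maps, history, win, nrow):
    #     tiles = (maps * 256 + 128).view(-1, 1, maps.shape[2], maps.shape[3])
    #     history.append(tiles.detach().cpu().numpy())
    #     stacked = np.array(history)
    #     grid = np.reshape(stacked, [-1, 1, stacked.shape[3], stacked.shape[4]])
    #     viz.images(grid, win=win, opts=dict(title='Random!'), nrow=nrow)
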
    # mm = model.conv1(ad)
    # datodis = mm * 256 + 128
    # datodis = datodis.view(64, 1, 24 * 10, 24)
    # imaglis.append(datodis.detach().cpu().numpy()[0:8, :, :, :])

    test()

    # cdsa = np.reshape(np.array(imaglis), newshape=[-1, 1, 240, 24])
    # viz.images(cdsa, win=imagewin, opts=dict(title='Random!', caption='How random.'), nrow=8, padding=2)

# Visualize the STN transformation on some input batch
# visualize_stn()

# plt.ioff()
# plt.show()