# witnn/Other/transformer_tutorial.py
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
import torchvision.models as models
import matplotlib.pyplot as plt
import numpy as np
from visdom import Visdom
viz = Visdom()
viz.delete_env('main')
imagewin1 = viz.image(np.random.rand(3, 512, 256), opts=dict(title='conv1 feature maps'))
imagewin2 = viz.image(np.random.rand(3, 512, 256), opts=dict(title='conv2 feature maps'))
imagewin3 = viz.image(np.random.rand(3, 512, 256), opts=dict(title='conv3 feature maps'))
imagewin4 = viz.image(np.random.rand(3, 512, 256), opts=dict(title='conv1 filters'))
imagewin5 = viz.image(np.random.rand(3, 512, 256), opts=dict(title='Random!', caption='How random.'))
imagewin6 = viz.image(np.random.rand(3, 512, 256), opts=dict(title='Random!', caption='How random.'))
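# Windows 1-4 are overwritten each epoch with conv1/conv2/conv3 feature maps
# and the conv1 filters; windows 5 and 6 are created but never updated below.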
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Training dataset
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(root='.', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=64, shuffle=True, num_workers=4)
# Test dataset
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(root='.', train=False,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=64, shuffle=True, num_workers=4)
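# (0.1307,) and (0.3081,) are the widely used mean and std of the MNIST
# training pixels, so inputs are normalized to roughly zero mean and unit
# variance.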
#modelnet = models.vgg11(pretrained=True)
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)
        # Spatial transformer localization-network
        self.localization = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=7),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True),
            nn.Conv2d(8, 10, kernel_size=5),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True)
        )
        # Regressor for the 3 * 2 affine matrix
        self.fc_loc = nn.Sequential(
            nn.Linear(10 * 3 * 3, 32),
            nn.ReLU(True),
            nn.Linear(32, 3 * 2)
        )
        # Initialize the weights/bias with identity transformation
        self.fc_loc[2].weight.data.zero_()
        self.fc_loc[2].bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))

    # Spatial transformer network forward function
    def stn(self, x):
        xs = self.localization(x)
        xs = xs.view(-1, 10 * 3 * 3)
        theta = self.fc_loc(xs)
        theta = theta.view(-1, 2, 3)
        grid = F.affine_grid(theta, x.size())
        x = F.grid_sample(x, grid)
        return x

    def forward(self, x):
        # transform the input
        x = self.stn(x)
        # Perform the usual forward pass
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)
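# Note: Net is the spatial-transformer model from the official PyTorch STN
# tutorial; this script only defines it and trains the plain NetMnist below.
# To train the STN instead, one could construct it the same way, e.g.:
#   model = Net().to(device)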
class NetMnist(nn.Module):
    def __init__(self):
        super(NetMnist, self).__init__()
        self.conv1 = nn.Conv2d(1, 4, kernel_size=5)
        self.conv2 = nn.Conv2d(4, 8, kernel_size=3)
        self.conv3 = nn.Conv2d(8, 16, kernel_size=5)
        self.fc1 = nn.Linear(1 * 16, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = F.relu(self.conv3(x))
        x = x.view(-1, 1 * 16)
        x = F.relu(self.fc1(x))
        return F.log_softmax(x, dim=1)
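# Shape walk-through for a 1x28x28 MNIST input:
#   conv1 (5x5): 28 -> 24, pool/2 -> 12
#   conv2 (3x3): 12 -> 10, pool/2 -> 5
#   conv3 (5x5):  5 -> 1
# leaving 16 channels of 1x1, hence the view(-1, 1*16) before fc1.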
#model = torch.nn.DataParallel(NetMnist()).to(device)
model = NetMnist().to(device)
optimizer = optim.SGD(model.parameters(), lr=0.01)
ad = None  # caches the first training batch for the visualizations below
def train(epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        # Remember the first batch ever seen, so the feature-map windows
        # always visualize the same inputs.
        global ad
        if ad is None:
            ad = data
        # nda=data.numpy()
        # nao=np.append(nda,nda,1)
        # nao = np.append(nao, nda, 1)
        # data=torch.tensor(nao)
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        # With batch_size=64 there are ~938 batches per epoch, so this prints
        # once near the end of each epoch.
        if batch_idx % 930 == 0 and batch_idx > 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
def test():
    with torch.no_grad():
        model.eval()
        test_loss = 0
        correct = 0
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # sum up batch loss
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            # get the index of the max log-probability
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()
        test_loss /= len(test_loader.dataset)
        print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
            test_loss, correct, len(test_loader.dataset),
            100. * correct / len(test_loader.dataset)))
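# The commented-out visualize_stn() call at the bottom of this file comes from
# the upstream STN tutorial. A minimal sketch of it is given here, assuming
# `model` were an instance of Net; since this script trains NetMnist (which
# has no stn method), the call at the bottom stays commented out.
def visualize_stn():
    with torch.no_grad():
        data = next(iter(test_loader))[0].to(device)
        # Undo the Normalize transform so the grids are viewable images.
        in_grid = torchvision.utils.make_grid(data.cpu()) * 0.3081 + 0.1307
        out_grid = torchvision.utils.make_grid(model.stn(data).cpu()) * 0.3081 + 0.1307
        viz.image(in_grid.numpy(), opts=dict(title='Dataset images'))
        viz.image(out_grid.numpy(), opts=dict(title='Transformed images'))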
imaglis1 = []
imaglis2 = []
imaglis3 = []
imaglis4 = []
imaglis5 = []
imaglis6 = []
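# Each imaglisN list accumulates one snapshot per epoch, so the Visdom windows
# show how the feature maps and filters evolve over training. Note that the
# history (and the uploaded image grid) grows linearly with the epoch count.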
for epoch in range(1, 100 + 1):
    train(epoch)

    # conv1 feature maps for the cached first sample
    mm1 = model.conv1(ad[0:1, :, :, :])
    datodis = mm1 * 256 + 128  # shift/scale activations into a displayable pixel range
    dashape = datodis.shape
    datodis = datodis.view(dashape[1], 1, dashape[2], dashape[3])
    imaglis1.append(datodis.detach().cpu().numpy())
    imaglisnp = np.array(imaglis1)
    cdsa = np.reshape(imaglisnp, newshape=[-1, 1, imaglisnp.shape[3], imaglisnp.shape[4]])
    viz.images(cdsa, win=imagewin1, opts=dict(title='conv1 feature maps'), nrow=dashape[1])

    # conv2 feature maps
    mm2 = model.conv2(F.relu(F.max_pool2d(mm1, 2)))
    datodis = mm2 * 256 + 128
    dashape = datodis.shape
    datodis = datodis.view(dashape[1], 1, dashape[2], dashape[3])
    imaglis2.append(datodis.detach().cpu().numpy())
    imaglisnp = np.array(imaglis2)
    cdsa = np.reshape(imaglisnp, newshape=[-1, 1, imaglisnp.shape[3], imaglisnp.shape[4]])
    viz.images(cdsa, win=imagewin2, opts=dict(title='conv2 feature maps'), nrow=dashape[1])

    # conv3 feature maps
    mm3 = model.conv3(F.relu(F.max_pool2d(mm2, 2)))
    datodis = mm3 * 256 + 128
    dashape = datodis.shape
    datodis = datodis.view(dashape[1], 1, dashape[2], dashape[3])
    imaglis3.append(datodis.detach().cpu().numpy())
    imaglisnp = np.array(imaglis3)
    cdsa = np.reshape(imaglisnp, newshape=[-1, 1, imaglisnp.shape[3], imaglisnp.shape[4]])
    viz.images(cdsa, win=imagewin3, opts=dict(title='conv3 feature maps'), nrow=dashape[1])

    # conv1 filter weights
    mm4 = model.conv1.weight.data
    datodis = mm4 * 256 + 128
    dashape = datodis.shape
    datodis = datodis.view(dashape[1] * dashape[0], 1, dashape[2], dashape[3])
    imaglis4.append(datodis.detach().cpu().numpy())
    imaglisnp = np.array(imaglis4)
    cdsa = np.reshape(imaglisnp, newshape=[-1, 1, imaglisnp.shape[3], imaglisnp.shape[4]])
    viz.images(cdsa, win=imagewin4, opts=dict(title='conv1 filters'), nrow=dashape[1] * dashape[0])

    # mm = model.conv1(ad)
    # datodis = mm * 256 + 128
    # datodis = datodis.view(64, 1, 24 * 10, 24)
    # imaglis.append(datodis.detach().cpu().numpy()[0:8,:,:,:])

    test()
#cdsa=np.reshape(np.array(imaglis),newshape=[-1,1,240,24])
#viz.images(cdsa, win=imagewin, opts=dict(title='Random!', caption='How random.'), nrow=8,padding=2)
# Visualize the STN transformation on some input batch
# visualize_stn()
#plt.ioff()
#plt.show()