# witnn/VisualNetwork/VisualVgg19.py
#
# 215 lines
# 6.2 KiB
# Python

from __future__ import print_function
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
import torchvision.models as models
import matplotlib.pyplot as plt
import numpy as np
import cv2
# Directory that holds this script, with a trailing "/" so that file names can
# be appended by plain string concatenation.
CurrentPath = os.path.split(os.path.realpath(__file__))[0] + "/"
print("Current Path :" + CurrentPath)

# Output folder for the rendered images. CurrentPath already ends with "/",
# so no extra separator is inserted here (avoids a "//" in the path).
image_out_path = CurrentPath + "imageoutVgg19/"
try:
    os.makedirs(image_out_path)
except OSError:
    # An already existing directory is fine (including the case where another
    # process created it first); anything else, such as missing permissions
    # or a regular file in the way, is a real error and is re-raised.
    if not os.path.isdir(image_out_path):
        raise

# Preferred compute device: CUDA when available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#region define network myself
# class Net(nn.Module):
# def __init__(self):
# super(Net, self).__init__()
# self.conv1 = nn.Conv2d(1, 4, kernel_size=5)
# self.conv2 = nn.Conv2d(4, 8, kernel_size=3)
# self.conv3 = nn.Conv2d(8, 16, kernel_size=5)
# self.fc1 = nn.Linear(1*16, 10)
#
# def forward(self, x):
#
# x = F.relu(F.max_pool2d(self.conv1(x), 2))
# x = F.relu(F.max_pool2d(self.conv2(x), 2))
# x = F.relu(self.conv3(x), 2)
#
# x = x.view(-1, 1*16)
# x = F.relu(self.fc1(x))
#
# return F.log_softmax(x, dim=1)
#
# netmodel = Net()
#
# state_dict = torch.load("mnistcnn.pth.tar", map_location='cpu')
# from collections import OrderedDict
# new_state_dict = OrderedDict()
# for k, v in state_dict.items():
# name = k
# if k[0:7] == "module.":
# name = k[7:]
# new_state_dict[name] = v
# netmodel.load_state_dict(new_state_dict)
# netmodel.eval()
# train_loader = torch.utils.data.DataLoader(
# datasets.MNIST(root='.', train=True, download=True,
# transform=transforms.Compose([
# transforms.ToTensor(),
# transforms.Normalize((0.1307,), (0.3081,))
# ])), batch_size=1, shuffle=True, num_workers=4)
# test_loader = torch.utils.data.DataLoader(
# datasets.MNIST(root='.', train=False, transform=transforms.Compose([
# transforms.ToTensor(),
# transforms.Normalize((0.1307,), (0.3081,))
# ])), batch_size=1, shuffle=True, num_workers=4)
#
# for batch_idx, (data, target) in enumerate(train_loader):
# data, target = data.to(device), target.to(device)
#
# output = netmodel(data)
# i=0
#endregion define network myself
#region define from torchvision models
# Build torchvision's VGG-19 (the positional True is passed through unchanged
# to the constructor) and switch it to evaluation mode right away;
# nn.Module.eval() returns the module itself, so the call can be chained.
netmodel = models.vgg19(True).eval()
# img = cv2.imread("t.jpg")
# b,g,r = cv2.split(img)
# img = cv2.merge([r,g,b])
# img = img.astype("float32")
# img = np.transpose(img,(2,0,1))
# img = torch.from_numpy(img)
#
# normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
# std=[0.229, 0.224, 0.225])
# img = img/256
# img = normalize(img)
#
# img = img.view(1,3,224,224)
# out = netmodel(img)
# out = out.view(-1).detach().numpy()
# index = np.argmax(out)
# value = out[index]
# i=0
#endregion define from torchvision models
class _Suggest(nn.Module):
    """Learnable input image: the image is the module's only parameter."""

    def __init__(self, initdata=None):
        super(_Suggest, self).__init__()
        if initdata is None:
            # No start image given: begin from uniform noise in [-1, 1).
            noise = np.random.uniform(-1, 1, (1, 3, 224, 224)).astype("float32")
            initdata = torch.from_numpy(noise)
        self.weight = nn.Parameter(initdata)

    def forward(self, x):
        x = x * self.weight
        # F.interpolate replaces the deprecated F.upsample (same arguments).
        return F.interpolate(x, (224, 224), mode='bilinear', align_corners=True)


def _load_init_image(initimagefile):
    """Read an image file and return it as a normalised (1, 3, 224, 224) float32 tensor."""
    img = cv2.imread(initimagefile)
    if img is None:
        # cv2.imread reports an unreadable/missing file by returning None.
        raise IOError("cannot read image file: " + str(initimagefile))
    if img.shape[:2] != (224, 224):
        # The optimised image is multiplied with a 224x224 tensor of ones, so
        # any other size has to be brought to 224x224 first.
        img = cv2.resize(img, (224, 224))
    # OpenCV loads BGR; the mean/std below are given in RGB order.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB).astype("float32")
    img = torch.from_numpy(np.transpose(img, (2, 0, 1)))
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    img = normalize(img / 256)
    # Add the batch axis so the parameter has the same rank as the noise init.
    return img.unsqueeze(0).contiguous()


def _to_display_image(chw):
    """Stretch every channel of a (C, H, W) array and return it as (H, W, C) with reversed channel order."""
    for c in range(chw.shape[0]):
        lo = chw[c].min()
        hi = chw[c].max()
        span = hi - lo
        if span == 0:
            # Constant channel: the stretch below would divide by zero.
            chw[c] = 0.0
            continue
        scale = 256.0 / span
        # Maps the channel's [min, max] onto [0, 256].
        chw[c] = chw[c] * scale - lo * scale
    hwc = np.transpose(chw, (1, 2, 0))
    # Reverse the channel axis (RGB -> BGR, the order cv2.imwrite expects);
    # copy to a contiguous array because the reversed view has negative strides.
    hwc = np.ascontiguousarray(hwc[:, :, ::-1])
    # Double the contrast around the mid value 128.
    hwc = hwc * 2
    hwc = hwc - (128 * (2 - 1))
    return hwc


def visualmodle(initimagefile, netmodel, layer, channel, iterations=100, lr=1.0):
    """Optimise an input image so that one feature-map channel of ``netmodel`` grows.

    A learnable 224x224 image is fed through ``netmodel``; the output of
    ``netmodel.features[layer]`` is captured with a forward hook and the image
    is updated with SGD so that the selected ``channel`` increases.

    Args:
        initimagefile: Path of an image used as the starting point, or None to
            start from uniform random noise.
        netmodel: Network exposing an indexable ``features`` container. It is
            put into eval mode and, when CUDA is available, moved to the GPU.
        layer: Index into ``netmodel.features`` of the layer to visualise.
        channel: Channel index inside that layer's output.
        iterations: Number of SGD steps (default 100).
        lr: SGD learning rate (default 1.0).

    Returns:
        A float32 numpy array of shape (224, 224, 3) in BGR channel order,
        contrast-stretched per channel; values lie roughly in [-128, 384].

    Raises:
        IOError: If ``initimagefile`` is given but cannot be read.
    """
    netmodel.eval()
    if initimagefile is None:
        model = _Suggest(None)
    else:
        model = _Suggest(_load_init_image(initimagefile))
    optimizer = optim.SGD(model.parameters(), lr=lr)
    model.train()

    # Constant all-ones input: the module output is then just its weight.
    data = torch.ones((1, 3, 224, 224), dtype=torch.float32)
    criterion = nn.MSELoss()
    if torch.cuda.is_available():
        criterion = criterion.cuda()
        model = model.cuda()
        netmodel = netmodel.cuda()
        data = data.cuda()

    captured = []

    def getnet(module, inputs, output):
        captured.append(output)

    # Register the hook once and always remove it again, so a failure inside
    # the loop cannot leave a stale hook on the (shared) network.
    handle = netmodel.features[layer].register_forward_hook(getnet)
    try:
        for i in range(iterations):
            del captured[:]
            netmodel(model(data))
            activation = captured[0][0, channel, :, :]
            # The target is always 256 above the current activation, so the
            # loss value is constant and only its gradient matters: every
            # step pushes the activation upwards.
            target = (activation + 256.0).detach()
            loss = criterion(activation, target)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            print('Train Inter:'+str(i) + " loss:"+str(loss.cpu().detach().numpy()))
    finally:
        handle.remove()
        del captured[:]

    model.eval()
    output = model(data)
    out = output.view(3, 224, 224).cpu().detach().numpy()
    return _to_display_image(out)
# 128 7
# 512 30
# Render one image per channel index 0..511 of features[36], starting from
# random noise each time, and write every result as a JPEG into the output folder.
for i in range(512):
    out = visualmodle(initimagefile=None, netmodel=netmodel, layer=36, channel=i)
    jpg_name = "L36_C{}.jpg".format(i)
    cv2.imwrite(image_out_path + jpg_name, out)
# Reset the loop variable once the loop has finished.
i = 0