# Source: witnn/PaintVgg19/PaintVgg19.py (Python, 154 lines, 5.2 KiB)
# Revision of 2019-08-19 15:53:10 +08:00
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
import torchvision.models as models
import matplotlib.pyplot as plt
import numpy as np
from visdom import Visdom
import cv2
import os
import shutil
from vgg19Pytorch import Vgg19Module
# Directory containing this script, with a trailing "/" (other paths in this
# script are built by concatenating onto it).
CurrentPath = os.path.split(os.path.realpath(__file__))[0] + "/"
print("Current Path :" + CurrentPath)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# The forward pass addresses layers by Caffe-style attribute names
# (vgg19.conv1_1 ... vgg19.conv5_3). torchvision's VGG does not have those
# attributes (its layers live in the nn.Sequential `features`), so load the
# converted Vgg19Module with its .npy weights instead. The BGR, mean-subtracted
# 0-255 input produced by readAndPreprocessImage presumably matches these
# converted weights as well - confirm against vgg19Pytorch.
vgg19 = Vgg19Module(os.path.join(CurrentPath, 'vgg19Pytorch.npy'))
vgg19 = vgg19.to(device)
vgg19.eval()  # inference only
def readAndPreprocessImage(imagepath):
    """Load an image file and turn it into a (1, 3, 224, 224) float32 tensor.

    The image is read with OpenCV, which returns channels in BGR order. The
    per-channel means [103.939, 116.779, 123.68] are subtracted in that same
    BGR order. The result is resized to 224x224, transposed to channels-first
    and given a leading batch axis. The tensor is moved to the GPU when CUDA
    is available.

    Args:
        imagepath: path of the image file to read.

    Returns:
        torch.Tensor of shape (1, 3, 224, 224), dtype float32.

    Raises:
        IOError: if OpenCV cannot read ``imagepath``.
    """
    # Means are in BGR order, matching cv2.imread. No BGR->RGB swap is done.
    # NOTE(review): this presumably matches the Caffe-converted VGG19 weights,
    # which expect BGR input - confirm.
    mean = np.array([103.939, 116.779, 123.68])
    img = cv2.imread(imagepath)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise IOError("Could not read image: %s" % imagepath)
    img = img.astype('float32')
    img -= mean
    img = cv2.resize(img, (224, 224)).transpose((2, 0, 1))
    # Add the batch axis and make the memory layout contiguous for torch.
    img = np.ascontiguousarray(img[np.newaxis, :, :, :])
    image = torch.from_numpy(img)
    if torch.cuda.is_available():
        image = image.cuda()
    return image
# Root of the ILSVRC2012 validation images, relative to this script.
imgdir = os.path.join(CurrentPath, '..', 'Dataset', 'ILSVRC2012_img_val')
if not os.path.isdir(imgdir):
    # os.walk silently yields nothing for a missing directory.
    print("Warning: image directory not found: " + imgdir)
allImageData = []  # full paths of the images to process
allfile = []       # matching bare file names, reused as output file names
for dirpath, _, filenames in os.walk(imgdir):
    # os.walk returns names in arbitrary filesystem order. Sort them so the
    # "first 1000 per directory" subset is the same on every run.
    for filename in sorted(filenames)[:1000]:
        allImageData.append(os.path.join(dirpath, filename))
        allfile.append(filename)
resultdata = []  # one 2-D conv5_3 channel map per image, in allImageData order
count = 0        # number of images processed so far (progress counter)
# Channel of the conv5_3 feature map recorded for every image. This is a
# hard-coded experiment parameter; its meaning is not documented here.
FEATURE_CHANNEL = 204
# (stage, number of 3x3 convs) for VGG19 stages 1-4. Each of these stages
# ends in a 2x2/stride-2 max pool. Stage 5 is handled separately because the
# forward pass stops at conv5_3.
_VGG19_POOLED_STAGES = ((1, 2), (2, 2), (3, 4), (4, 4))


def _padded_conv(net, layer_name, x):
    """Zero-pad ``x`` by 1 pixel on every spatial side, then apply net.<layer_name>."""
    # The explicit pad suggests the layers were built without their own
    # padding (Caffe-converted style) - NOTE(review): confirm.
    return getattr(net, layer_name)(F.pad(x, (1, 1, 1, 1)))


def _forward_to_conv5_3(net, image):
    """Run the VGG19 conv stack on ``image`` and return conv5_3 before its ReLU.

    ``net`` must expose layers named conv1_1 ... conv5_3 as attributes.
    """
    x = image
    for stage, n_convs in _VGG19_POOLED_STAGES:
        for k in range(1, n_convs + 1):
            x = F.relu(_padded_conv(net, "conv%d_%d" % (stage, k), x))
        # ceil_mode=True rounds the pooled size up, as in the original code.
        x = F.max_pool2d(x, kernel_size=(2, 2), stride=(2, 2), padding=0, ceil_mode=True)
    x = F.relu(_padded_conv(net, "conv5_1", x))
    x = F.relu(_padded_conv(net, "conv5_2", x))
    # Only the pre-ReLU conv5_3 map is recorded, so conv5_4 (previously
    # computed but never used) is skipped.
    return _padded_conv(net, "conv5_3", x)


# Inference only: no_grad avoids building an autograd graph for every image.
with torch.no_grad():
    for imagefile in allImageData:
        count += 1
        image = readAndPreprocessImage(imagefile)
        conv5_3 = _forward_to_conv5_3(vgg19, image)
        # Keep channel FEATURE_CHANNEL of the single batch item as a numpy map.
        resultdata.append(conv5_3.detach().cpu().numpy()[0, FEATURE_CHANNEL, :, :])
        print("CalImageCount:" + str(count))
# 256-bin histogram over every recorded conv5_3 activation value.
his = np.histogram(resultdata, 256)
xx = np.asarray(his[0], dtype='float32')  # per-bin counts
yy = np.asarray(his[1], dtype='float32')  # 257 bin edges
# Table with one row per bin: column 0 is the bin's left edge, column 1 is
# its count. It is stored as float64.
data = np.column_stack((yy[0:256], xx)).astype(np.float64)
# Activation value above which a feature-map cell is marked on its image.
threvalue = 500
outdir = os.path.join(CurrentPath, "imagePointOut")
# cv2.imwrite does not create directories and only returns False on failure,
# so make sure the output directory exists before writing.
if not os.path.isdir(outdir):
    os.makedirs(outdir)
# resultdata was filled in allImageData order, so the three lists line up.
for path1, filename, daimage in zip(allImageData, allfile, resultdata):
    if not np.max(daimage) > threvalue:
        continue
    ima = cv2.imread(path1)
    imageheight, imagewidth = ima.shape[0], ima.shape[1]
    dataheight, datawidth = daimage.shape[0], daimage.shape[1]
    # Scale factors from feature-map cell indices to image pixel coordinates.
    widthratio = float(imagewidth) / float(datawidth)
    heightratio = float(imageheight) / float(dataheight)
    # np.argwhere yields (row, col) pairs in row-major order, the same order
    # as a nested row/column loop.
    for i, j in np.argwhere(daimage > threvalue):
        cv2.circle(ima, (int(j * widthratio), int(i * heightratio)), 15, (255, 0, 0), 2)
    path2 = os.path.join(outdir, filename)
    if not cv2.imwrite(path2, ima):
        print("Failed to write " + path2)