# Source: witnn/Other/UnSupperviseSelfData.py (186 lines, 5.5 KiB, Python)
# Revision date: 2019-08-19 15:53:10 +08:00
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
import torchvision.models as models
import matplotlib.pyplot as plt
import numpy as np
from visdom import Visdom
# viz=Visdom()
# viz.delete_env('main')
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def default_loader(path):
    """Build one synthetic 1x28x28 float32 "image"; ``path`` is ignored.

    The image is integer noise drawn from [0, 255) with the 2x2 patch at
    rows/columns 15-16 forced to 255.
    """
    noise = np.random.randint(0, 255, (1, 28, 28))
    image = noise.astype("float32")
    image[0, 15:17, 15:17] = 255
    return image
class MyDataset(Dataset):
    """Fixed-size synthetic dataset whose items are produced by ``loader``.

    ``transform`` and ``target_transform`` are stored on the instance but are
    not applied in ``__getitem__``.
    """

    def __init__(self, imagepath, transform=None, target_transform=None, loader=default_loader):
        # 10000 identical (path, label) entries; every label is 0.
        self.imgs = [(imagepath, 0) for _ in range(10000)]
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

    def __getitem__(self, index):
        """Return ``(tensor, label)`` for the entry at ``index``."""
        path, label = self.imgs[index]
        # The loader returns a NumPy array; wrap it as a tensor.
        return torch.from_numpy(self.loader(path)), label

    def __len__(self):
        """Return the number of entries in the dataset."""
        return len(self.imgs)
# Synthetic training set. The transform pipeline is handed to MyDataset, which
# stores it without applying it in __getitem__. No augmentation or
# normalisation is configured.
train_data = MyDataset(
    imagepath="",
    transform=transforms.Compose([transforms.ToTensor()]),
)

# Shuffled mini-batches of 64, loaded by one worker process; the incomplete
# final batch is dropped.
train_loader = DataLoader(
    train_data,
    batch_size=64,
    shuffle=True,
    drop_last=True,
    num_workers=1,
)
# # Training dataset
# train_loader = torch.utils.data.DataLoader(
# datasets.CIFAR10(root='.', train=True, download=True,
# transform=transforms.Compose([
# transforms.ToTensor(),
# #transforms.Normalize((0.1307,), (0.3081,))
# ])), batch_size=128, shuffle=True, num_workers=1)
# # Test dataset
# val_loader = torch.utils.data.DataLoader(
# datasets.CIFAR10(root='.', train=False, transform=transforms.Compose([
# transforms.ToTensor(),
# transforms.Normalize((0.1307,), (0.3081,))
# ])), batch_size=32, shuffle=True, num_workers=1)
class NetMnist(nn.Module):
    """One 3x3 single-channel convolution whose kernel is re-normalised on every forward pass."""

    def __init__(self):
        super(NetMnist, self).__init__()
        channels = 1
        self.conv1 = nn.Conv2d(1, channels, kernel_size=3, padding=0)

    def forward(self, x):
        """Return ``|conv1(x)|`` after projecting the kernel to zero mean and unit L1 norm.

        The projected kernel is written back into ``conv1.weight``, so the
        layer's stored weights change as a side effect of calling ``forward``.
        """
        # Work on the raw weight values (outside autograd), flattened to 9 entries.
        kernel = self.conv1.weight.data.view(9)
        centered = kernel - kernel.mean()
        scaled = centered / centered.abs().sum()
        self.conv1.weight.data = scaled.view(1, 1, 3, 3)
        return self.conv1(x).abs()
model = NetMnist().to(device)
#########################################################
# Plain SGD over the network parameters, scored with mean-squared error.
lossfunc = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)
# Device ids referenced only by the disabled DataParallel wrappers below.
gpu_ids = [0, 1, 2, 3]
# model = torch.nn.DataParallel(model, device_ids = gpu_ids)
# optimizer = torch.nn.DataParallel(optimizer, device_ids = gpu_ids)
def train(epoch):
    """Run one pass over ``train_loader``, updating ``model`` with ``optimizer``.

    The regression target is built from the network's own output
    (``output + 0.1``, detached); the label coming from the dataset is not
    used in the loss.

    Args:
        epoch: Epoch number; only used in the progress message.
    """
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        # The dataset label is overwritten here: the target becomes the
        # current output shifted by 0.1.
        target = output + 0.1
        # Detach so gradients flow only through `output`, not the target.
        var_no_grad = target.detach()
        loss = lossfunc(output, var_no_grad)
        loss.backward()
        optimizer.step()
        # Report every 10th batch, skipping batch 0.
        if batch_idx % 10 == 0 and batch_idx>0 :
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch, batch_idx * len(data), len(train_loader.dataset),100. * batch_idx / len(train_loader), loss.item()))
            # NOTE(review): the original indentation was lost; the lines below
            # are assumed to belong to this logging branch — confirm.
            # Compute a zero-mean, unit-L1 copy of the kernel for display;
            # `da - damean` yields a new tensor, so the stored weights are
            # not modified here.
            da=model.conv1.weight.data
            da = da.view(9)
            damean = da.mean()
            da = da - damean
            daabssum = da.abs().sum()
            da = da / daabssum
            da = da.view(1, 1, 3, 3)
            print(da)
def val():
    """Print the fraction of ``val_loader`` samples whose prediction equals the label.

    The prediction is the index of the largest output value along dim 1.

    NOTE(review): ``val_loader`` is not defined in the visible source (it only
    appears in commented-out code), and the call to ``val()`` is commented
    out — confirm before re-enabling.
    """
    model.eval()
    hits = 0
    with torch.no_grad():
        for data, target in val_loader:
            data = data.to(device)
            target = target.to(device)
            output = model(data)
            # Second element of max() is the arg-max index per sample.
            pred = output.max(1, keepdim=True)[1]
            hits += pred.eq(target.view_as(pred)).sum().item()
    total = len(val_loader.dataset)
    accuracy = float(hits) / total
    print('\nTest set: val: {:.6f}\n'.format(accuracy))
# Guard the entry point: train_loader uses a worker process (num_workers=1),
# and on platforms that start workers with "spawn" the worker re-imports this
# module. An unguarded loop would then try to start training again inside
# the worker instead of only serving batches.
if __name__ == "__main__":
    # Epochs are numbered from 1; range(1, 3000) runs 2999 of them.
    for epoch in range(1, 3000):
        train(epoch)
        #val()