witnn/CNNDemo/Minist.py

74 lines
2.0 KiB
Python
Raw Normal View History

2020-07-13 20:14:01 +08:00
from __future__ import print_function
import os
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
import torchvision.models as models
import matplotlib.pyplot as plt
import numpy as np
# Absolute directory of this script, with a trailing "/" so sibling paths can
# be formed by plain string concatenation.
CurrentPath = os.path.dirname(os.path.realpath(__file__)) + "/"
# Extend the import search path so the `tools` imports that follow can resolve.
sys.path.extend([CurrentPath + '../tools', CurrentPath + '../'])
from tools import UniModule
from tools import utils, Train, Loader, WebVisual
class MINISTCNN(UniModule.ModuleBase):
    """Three-stage convolutional classifier returning log-probabilities.

    Per-stage spatial sizes for a 28x28 single-channel input (the size the
    layer arithmetic below is built around):
        layer1: 5x5 conv -> 24x24, 2x2 max-pool -> 12x12, 16 channels
        layer2: 3x3 conv -> 10x10, 2x2 max-pool -> 5x5,   32 channels
        layer3: 3x3 conv -> 3x3 (no pooling),             64 channels
    The flattened 64*3*3 features pass through two Linear layers down to 10
    outputs, and ``forward`` returns ``log_softmax`` over dim 1.
    """

    def __init__(self):
        super(MINISTCNN, self).__init__()
        # Attribute names and Sequential ordering determine the state_dict keys
        # (e.g. "layer1.0.weight", "fc.1.bias"); keep them stable so saved
        # checkpoints still load.
        self.layer1 = self._conv_stage(1, 16, 5, pooled=True)
        self.layer2 = self._conv_stage(16, 32, 3, pooled=True)
        self.layer3 = self._conv_stage(32, 64, 3, pooled=False)
        # NOTE(review): there is no activation between these two Linear layers
        # (a ReLU was commented out), so together they form a single affine map.
        self.fc = nn.Sequential(
            nn.Linear(64 * 3 * 3, 512),
            nn.Linear(512, 10),
        )

    @staticmethod
    def _conv_stage(in_ch, out_ch, kernel, pooled):
        """Build Conv2d -> BatchNorm2d -> ReLU, optionally followed by a 2x2 max-pool."""
        parts = [
            nn.Conv2d(in_ch, out_ch, kernel_size=kernel),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
        ]
        if pooled:
            parts.append(nn.MaxPool2d(kernel_size=2, stride=2))
        return nn.Sequential(*parts)

    def forward(self, x):
        for stage in (self.layer1, self.layer2, self.layer3):
            x = stage(x)
        # Collapse everything but the batch dimension before the classifier head.
        flat = x.view(x.size(0), -1)
        return F.log_softmax(self.fc(flat), dim=1)
if __name__ == "__main__":
    # Run training only when executed as a script. Without this guard, importing
    # the module (e.g. to reuse MINISTCNN) starts a full training run. Also, if
    # Loader.MNIST builds DataLoaders with num_workers=4 (presumably so), worker
    # processes started with the "spawn" method (Windows, macOS default)
    # re-import this module and would try to start training again.
    WebVisual.InitVisdom()
    window = WebVisual.LineWin()
    linePre = WebVisual.Line(window, "linePre")
    traindata, testdata = Loader.MNIST(
        32, num_workers=4, trainsize=60000, shuffle=True)
    modelMinist = utils.SetDevice(MINISTCNN())
    # Optional: start from previously saved weights instead of from scratch.
    # modelMinist = utils.LoadModel(
    #     modelMinist, os.path.join(CurrentPath, "checkpointEntropySearch.pkl"))
    optimizer = optim.SGD(modelMinist.parameters(), lr=0.0005)
    # NOTE(review): the positional 100 and 1 are interpreted by
    # Train.TrainEpochs (presumably epoch count and an eval/log interval) — confirm.
    Train.TrainEpochs(modelMinist, traindata, optimizer,
                      testdata, 100, 1, linePre)
    # CurrentPath already ends with "/", so join instead of prepending another "/".
    utils.SaveModel(modelMinist,
                    os.path.join(CurrentPath, "checkpointMnistTrain.pkl"))