74 lines
2.0 KiB
Python
74 lines
2.0 KiB
Python
|
from __future__ import print_function
|
||
|
|
||
|
import os
|
||
|
import sys
|
||
|
import torch
|
||
|
import torch.nn as nn
|
||
|
import torch.nn.functional as F
|
||
|
import torch.optim as optim
|
||
|
import torchvision
|
||
|
from torchvision import datasets, transforms
|
||
|
import torchvision.models as models
|
||
|
import matplotlib.pyplot as plt
|
||
|
import numpy as np
|
||
|
|
||
|
|
||
|
# Absolute directory of this script, kept with a trailing "/" so relative
# suffixes can be concatenated onto it directly.
CurrentPath = os.path.dirname(os.path.realpath(__file__)) + "/"

# Append ../tools and then the parent directory to the import search path
# (same order as before), so that the project-local `tools` modules can be
# imported from this script.
sys.path.extend([CurrentPath + '../tools', CurrentPath + '../'])
|
||
|
|
||
|
from tools import UniModule
|
||
|
from tools import utils, Train, Loader, WebVisual
|
||
|
|
||
|
|
||
|
class MINISTCNN(UniModule.ModuleBase):
    """Three-stage convolutional classifier producing 10-way log-probabilities.

    Sized for single-channel 28x28 inputs; per-stage spatial sizes:
        layer1: conv 5x5 -> 24x24, max-pool 2 -> 12x12 (16 channels)
        layer2: conv 3x3 -> 10x10, max-pool 2 -> 5x5   (32 channels)
        layer3: conv 3x3 -> 3x3                       (64 channels)
    The 64*3*3 feature map is flattened and passed through ``fc``.
    """

    @staticmethod
    def _conv_stage(in_channels, out_channels, kernel, pool):
        """Build Conv2d -> BatchNorm2d -> ReLU, optionally followed by 2x2/2 max-pooling."""
        modules = [
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        ]
        if pool:
            modules.append(nn.MaxPool2d(kernel_size=2, stride=2))
        return nn.Sequential(*modules)

    def __init__(self):
        super(MINISTCNN, self).__init__()
        # Stages are created in the same order as the submodule names so that
        # parameter initialisation and state_dict keys stay stable.
        self.layer1 = self._conv_stage(1, 16, 5, pool=True)    # 28 -> 24 -> 12
        self.layer2 = self._conv_stage(16, 32, 3, pool=True)   # 12 -> 10 -> 5
        self.layer3 = self._conv_stage(32, 64, 3, pool=False)  # 5 -> 3
        # NOTE(review): there is no activation between these two Linear layers
        # (a ReLU was commented out here), so together they are equivalent to a
        # single affine map. Inserting one into this Sequential would renumber
        # the second layer's state_dict keys — confirm intent before changing.
        self.fc = nn.Sequential(
            nn.Linear(64 * 3 * 3, 512),
            nn.Linear(512, 10),
        )

    def forward(self, x):
        """Run the conv stages, flatten per sample, and return log-softmax scores of shape (batch, 10)."""
        for stage in (self.layer1, self.layer2, self.layer3):
            x = stage(x)
        flat = x.view(x.size(0), -1)
        return F.log_softmax(self.fc(flat), dim=1)
|
||
|
|
||
|
|
||
|
|
||
|
# Script entry point. The guard is required because the loaders are built with
# num_workers=4: under the "spawn" multiprocessing start method (default on
# Windows, and on macOS since Python 3.8) each worker re-imports this module,
# and without the guard it would re-run visdom setup, data loading and training.
if __name__ == "__main__":
    # Visdom live plot that receives training progress.
    WebVisual.InitVisdom()
    window = WebVisual.LineWin()
    linePre = WebVisual.Line(window, "linePre")

    # Batch size 32 over the 60000-sample MNIST training split, shuffled.
    traindata, testdata = Loader.MNIST(
        32, num_workers=4, trainsize=60000, shuffle=True)

    modelMinist = utils.SetDevice(MINISTCNN())

    # Optional: load previously saved weights before training.
    # modelMinist = utils.LoadModel(modelMinist, os.path.join(CurrentPath, "checkpointEntropySearch.pkl"))

    optimizer = optim.SGD(modelMinist.parameters(), lr=0.0005)
    # NOTE(review): the positional 100 and 1 are interpreted by
    # Train.TrainEpochs (presumably epoch count and a logging/eval interval) —
    # TODO confirm against tools/Train.
    Train.TrainEpochs(modelMinist, traindata, optimizer,
                      testdata, 100, 1, linePre)

    # CurrentPath already ends with a separator; join avoids a doubled "/".
    utils.SaveModel(modelMinist, os.path.join(CurrentPath, "checkpointMnistTrain.pkl"))
|
||
|
|
||
|
|