add MnistDemo
This commit is contained in:
parent
67c5cc13a0
commit
7ab55557cc
|
@ -0,0 +1,73 @@
|
|||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import sys
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.optim as optim
|
||||
import torchvision
|
||||
from torchvision import datasets, transforms
|
||||
import torchvision.models as models
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
||||
|
||||
CurrentPath = os.path.split(os.path.realpath(__file__))[0]+"/"
|
||||
sys.path.append(CurrentPath+'../tools')
|
||||
sys.path.append(CurrentPath+'../')
|
||||
|
||||
from tools import UniModule
|
||||
from tools import utils, Train, Loader, WebVisual
|
||||
|
||||
|
||||
class MINISTCNN(UniModule.ModuleBase):
|
||||
def __init__(self):
|
||||
super(MINISTCNN, self).__init__()
|
||||
self.layer1 = nn.Sequential( # 28-4=24 /2=12
|
||||
nn.Conv2d(1, 16, kernel_size=5),
|
||||
nn.BatchNorm2d(16),
|
||||
nn.ReLU(),
|
||||
nn.MaxPool2d(kernel_size=2, stride=2))
|
||||
self.layer2 = nn.Sequential( # 12-2=10 /2=5
|
||||
nn.Conv2d(16, 32, kernel_size=3),
|
||||
nn.BatchNorm2d(32),
|
||||
nn.ReLU(),
|
||||
nn.MaxPool2d(kernel_size=2, stride=2))
|
||||
self.layer3 = nn.Sequential( # 5-2=3
|
||||
nn.Conv2d(32, 64, kernel_size=3),
|
||||
nn.BatchNorm2d(64),
|
||||
nn.ReLU()
|
||||
)
|
||||
self.fc = nn.Sequential(
|
||||
nn.Linear(64*3*3, 512),
|
||||
# nn.ReLU(),
|
||||
nn.Linear(512, 10))
|
||||
|
||||
def forward(self, x):
|
||||
x = self.layer1(x)
|
||||
x = self.layer2(x)
|
||||
x = self.layer3(x)
|
||||
x = x.view(x.size(0), -1)
|
||||
x = self.fc(x)
|
||||
return F.log_softmax(x, dim=1)
|
||||
|
||||
|
||||
|
||||
WebVisual.InitVisdom()
|
||||
window = WebVisual.LineWin()
|
||||
linePre = WebVisual.Line(window, "linePre")
|
||||
|
||||
traindata, testdata = Loader.MNIST(
|
||||
32, num_workers=4, trainsize=60000, shuffle=True)
|
||||
|
||||
modelMinist = utils.SetDevice(MINISTCNN())
|
||||
|
||||
# model = utils.LoadModel(model, CurrentPath+"/checkpointEntropySearch.pkl")
|
||||
|
||||
optimizer = optim.SGD(modelMinist.parameters(), lr=0.0005)
|
||||
Train.TrainEpochs(modelMinist, traindata, optimizer,
|
||||
testdata, 100, 1, linePre)
|
||||
utils.SaveModel(modelMinist, CurrentPath+"/checkpointMnistTrain.pkl")
|
||||
|
||||
|
Binary file not shown.
Loading…
Reference in New Issue