diff --git a/MnistDemo/Minist.py b/MnistDemo/Minist.py new file mode 100644 index 0000000..314aa04 --- /dev/null +++ b/MnistDemo/Minist.py @@ -0,0 +1,73 @@ +from __future__ import print_function + +import os +import sys +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +import torchvision +from torchvision import datasets, transforms +import torchvision.models as models +import matplotlib.pyplot as plt +import numpy as np + + +CurrentPath = os.path.split(os.path.realpath(__file__))[0]+"/" +sys.path.append(CurrentPath+'../tools') +sys.path.append(CurrentPath+'../') + +from tools import UniModule +from tools import utils, Train, Loader, WebVisual + + +class MINISTCNN(UniModule.ModuleBase): + def __init__(self): + super(MINISTCNN, self).__init__() + self.layer1 = nn.Sequential( # 28-4=24 /2=12 + nn.Conv2d(1, 16, kernel_size=5), + nn.BatchNorm2d(16), + nn.ReLU(), + nn.MaxPool2d(kernel_size=2, stride=2)) + self.layer2 = nn.Sequential( # 12-2=10 /2=5 + nn.Conv2d(16, 32, kernel_size=3), + nn.BatchNorm2d(32), + nn.ReLU(), + nn.MaxPool2d(kernel_size=2, stride=2)) + self.layer3 = nn.Sequential( # 5-2=3 + nn.Conv2d(32, 64, kernel_size=3), + nn.BatchNorm2d(64), + nn.ReLU() + ) + self.fc = nn.Sequential( + nn.Linear(64*3*3, 512), + # nn.ReLU(), + nn.Linear(512, 10)) + + def forward(self, x): + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = x.view(x.size(0), -1) + x = self.fc(x) + return F.log_softmax(x, dim=1) + + + +WebVisual.InitVisdom() +window = WebVisual.LineWin() +linePre = WebVisual.Line(window, "linePre") + +traindata, testdata = Loader.MNIST( + 32, num_workers=4, trainsize=60000, shuffle=True) + +modelMinist = utils.SetDevice(MINISTCNN()) + +# model = utils.LoadModel(model, CurrentPath+"/checkpointEntropySearch.pkl") + +optimizer = optim.SGD(modelMinist.parameters(), lr=0.0005) +Train.TrainEpochs(modelMinist, traindata, optimizer, + testdata, 100, 1, linePre) +utils.SaveModel(modelMinist, CurrentPath+"/checkpointMnistTrain.pkl") + + diff --git a/MnistDemo/checkpointMnistTrain.pkl b/MnistDemo/checkpointMnistTrain.pkl new file mode 100644 index 0000000..e84eb44 Binary files /dev/null and b/MnistDemo/checkpointMnistTrain.pkl differ