from __future__ import print_function
import os
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
import torchvision.models as models
import matplotlib.pyplot as plt
import numpy as np

# Absolute directory of this script, with a trailing "/".
CurrentPath = os.path.split(os.path.realpath(__file__))[0] + "/"
# Make the project's `tools` package (one level up) importable when this
# file is run directly as a script.
sys.path.append(CurrentPath + '../tools')
sys.path.append(CurrentPath + '../')
from tools import UniModule
from tools import utils, Train, Loader, WebVisual


class MINISTCNN(UniModule.ModuleBase):
    """Small three-stage convolutional classifier for MNIST digits.

    The spatial-size comments below and the ``64*3*3`` input width of ``fc``
    are computed for single-channel 28x28 inputs; other sizes will not match
    the first Linear layer.  ``forward`` returns log-probabilities over 10
    classes.
    """

    def __init__(self):
        super(MINISTCNN, self).__init__()
        # 28x28 -> conv5 -> 24x24 -> maxpool/2 -> 12x12
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=5),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        # 12x12 -> conv3 -> 10x10 -> maxpool/2 -> 5x5
        self.layer2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=3),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        # 5x5 -> conv3 -> 3x3 (no pooling), 64 channels -> 64*3*3 features
        self.layer3 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3),
            nn.BatchNorm2d(64),
            nn.ReLU())
        # NOTE(review): with the ReLU disabled, these two Linear layers
        # compose into a single affine map, so the hidden 512 width adds no
        # expressive power.  Re-enabling it would shift the second Linear
        # from index fc.1 to fc.2 and break loading existing checkpoints.
        self.fc = nn.Sequential(
            nn.Linear(64 * 3 * 3, 512),
            # nn.ReLU(),
            nn.Linear(512, 10))

    def forward(self, x):
        """Return per-class log-probabilities for a batch of images.

        Args:
            x: image batch; presumably shaped (N, 1, 28, 28) -- see the
               class docstring for why other sizes will not fit ``fc``.

        Returns:
            Tensor of shape (N, 10) holding ``log_softmax`` over classes
            (suitable for ``F.nll_loss``; confirm that is what
            ``Train.TrainEpochs`` uses).
        """
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = x.view(x.size(0), -1)  # flatten to (N, 64*3*3)
        x = self.fc(x)
        return F.log_softmax(x, dim=1)


def main():
    """Train MINISTCNN on MNIST with Visdom plotting, then save a checkpoint."""
    WebVisual.InitVisdom()
    window = WebVisual.LineWin()
    linePre = WebVisual.Line(window, "linePre")

    traindata, testdata = Loader.MNIST(
        32, num_workers=4, trainsize=60000, shuffle=True)

    modelMinist = utils.SetDevice(MINISTCNN())
    # To resume from an earlier checkpoint instead of training from scratch:
    # modelMinist = utils.LoadModel(
    #     modelMinist, os.path.join(CurrentPath, "checkpointEntropySearch.pkl"))

    optimizer = optim.SGD(modelMinist.parameters(), lr=0.0005)
    # NOTE(review): the positional 100 and 1 are passed straight through to
    # Train.TrainEpochs (presumably epoch count and a log/eval interval --
    # confirm against its signature).
    Train.TrainEpochs(modelMinist, traindata, optimizer, testdata, 100, 1, linePre)

    # os.path.join avoids the "//" produced by CurrentPath + "/..." since
    # CurrentPath already ends with a separator.
    utils.SaveModel(modelMinist,
                    os.path.join(CurrentPath, "checkpointMnistTrain.pkl"))


# The guard is required because the loader uses num_workers=4: under the
# "spawn" start method (default on Windows and macOS) each worker re-imports
# this module, and unguarded top-level training code would run again there.
if __name__ == "__main__":
    main()