From 8bb2c5562d1193f86f47ff876b204bb9c66c7ee1 Mon Sep 17 00:00:00 2001 From: Colin Date: Wed, 6 Nov 2019 15:17:14 +0800 Subject: [PATCH] Add bn layer to 3335,Get if it is benify that. --- .gitignore | 1 + FilterEvaluator/Model.py | 17 +++++++++++++++++ FilterEvaluator/TrainNetwork.py | 14 +++++++++++--- 3 files changed, 29 insertions(+), 3 deletions(-) diff --git a/.gitignore b/.gitignore index 452be8f..b9d51f7 100644 --- a/.gitignore +++ b/.gitignore @@ -8,3 +8,4 @@ Dataset/ /*/__pycache__ .mypy_cache /*/image* +/*/visdom.server.log \ No newline at end of file diff --git a/FilterEvaluator/Model.py b/FilterEvaluator/Model.py index b67007a..bffba21 100644 --- a/FilterEvaluator/Model.py +++ b/FilterEvaluator/Model.py @@ -60,6 +60,23 @@ class Net3335(UniModule.ModuleBase): x = x.view(-1, 1*10) return F.log_softmax(x, dim=1) +class Net3335BN(UniModule.ModuleBase): + def __init__(self): + super(Net3335BN, self).__init__() + layers = [] + layers += [nn.Conv2d(1, 8, kernel_size=3,bias=False,padding=1),nn.MaxPool2d(kernel_size=2, stride=2),nn.Sigmoid()] + layers += [nn.BatchNorm2d(8)] + layers += [nn.Conv2d(8, 8, kernel_size=3,bias=False),nn.MaxPool2d(kernel_size=2, stride=2),nn.Sigmoid()] + layers += [nn.BatchNorm2d(8)] + layers += [nn.Conv2d(8, 8, kernel_size=3,bias=False),nn.Sigmoid()] + layers += [nn.BatchNorm2d(8)] + layers += [nn.Conv2d(8, 10, kernel_size=5,bias=False)] + self.features = nn.Sequential(*layers) + def forward(self, x): + x = self.features(x) + x = x.view(-1, 1*10) + return F.log_softmax(x, dim=1) + class Net3Grad335(UniModule.ModuleBase): def __init__(self): super(Net3Grad335, self).__init__() diff --git a/FilterEvaluator/TrainNetwork.py b/FilterEvaluator/TrainNetwork.py index 25704a0..9b95b55 100644 --- a/FilterEvaluator/TrainNetwork.py +++ b/FilterEvaluator/TrainNetwork.py @@ -51,6 +51,7 @@ traindata, testdata = Loader.Cifar10Mono(batchsize, num_workers=4,shuffle=True,t WebVisual.InitVisdom() window = WebVisual.LineWin() lineNoPre = WebVisual.Line(window, "NoPre") +lineNoPreBN = WebVisual.Line(window, "NoPreBN") linePretrainSearch = WebVisual.Line(window, "PretrainSearch") linePretrainTrain = WebVisual.Line(window, "PretrainTrain") @@ -58,10 +59,16 @@ linePretrainTrain = WebVisual.Line(window, "PretrainTrain") -model = utils.SetDevice(Model.Net3Grad335()) -model = utils.LoadModel(model, CurrentPath+"/checkpointTrain.pkl") +# model = utils.SetDevice(Model.Net3Grad335()) +# model = utils.LoadModel(model, CurrentPath+"/checkpointTrain.pkl") +# optimizer = optim.SGD(model.parameters(), lr=0.1) +# Train.TrainEpochs(model,traindata,optimizer,testdata,3000,30,linePretrainTrain) + + +model = utils.SetDevice(Model.Net3335BN()) +# model = utils.LoadModel(model, CurrentPath+"/checkpointTrain.pkl") optimizer = optim.SGD(model.parameters(), lr=0.1) -Train.TrainEpochs(model,traindata,optimizer,testdata,3000,30,linePretrainTrain) +Train.TrainEpochs(model,traindata,optimizer,testdata,3000,30,lineNoPreBN) model = utils.SetDevice(Model.Net3335()) @@ -70,6 +77,7 @@ optimizer = optim.SGD(model.parameters(), lr=0.1) Train.TrainEpochs(model,traindata,optimizer,testdata,3000,30,lineNoPre) + # model = utils.SetDevice(Model.Net3Grad335()) # model = utils.LoadModel(model, CurrentPath+"/checkpointSearch.pkl") # optimizer = optim.SGD(model.parameters(), lr=0.1)