diff --git a/.gitignore b/.gitignore index 452be8f..b9d51f7 100644 --- a/.gitignore +++ b/.gitignore @@ -8,3 +8,4 @@ Dataset/ /*/__pycache__ .mypy_cache /*/image* +/*/visdom.server.log \ No newline at end of file diff --git a/FilterEvaluator/Model.py b/FilterEvaluator/Model.py index b67007a..bffba21 100644 --- a/FilterEvaluator/Model.py +++ b/FilterEvaluator/Model.py @@ -60,6 +60,23 @@ class Net3335(UniModule.ModuleBase): x = x.view(-1, 1*10) return F.log_softmax(x, dim=1) +class Net3335BN(UniModule.ModuleBase): + def __init__(self): + super(Net3335BN, self).__init__() + layers = [] + layers += [nn.Conv2d(1, 8, kernel_size=3,bias=False,padding=1),nn.MaxPool2d(kernel_size=2, stride=2),nn.Sigmoid()] + layers += [nn.BatchNorm2d(8)] + layers += [nn.Conv2d(8, 8, kernel_size=3,bias=False),nn.MaxPool2d(kernel_size=2, stride=2),nn.Sigmoid()] + layers += [nn.BatchNorm2d(8)] + layers += [nn.Conv2d(8, 8, kernel_size=3,bias=False),nn.Sigmoid()] + layers += [nn.BatchNorm2d(8)] + layers += [nn.Conv2d(8, 10, kernel_size=5,bias=False)] + self.features = nn.Sequential(*layers) + def forward(self, x): + x = self.features(x) + x = x.view(-1, 1*10) + return F.log_softmax(x, dim=1) + class Net3Grad335(UniModule.ModuleBase): def __init__(self): super(Net3Grad335, self).__init__() diff --git a/FilterEvaluator/TrainNetwork.py b/FilterEvaluator/TrainNetwork.py index 25704a0..9b95b55 100644 --- a/FilterEvaluator/TrainNetwork.py +++ b/FilterEvaluator/TrainNetwork.py @@ -51,6 +51,7 @@ traindata, testdata = Loader.Cifar10Mono(batchsize, num_workers=4,shuffle=True,t WebVisual.InitVisdom() window = WebVisual.LineWin() lineNoPre = WebVisual.Line(window, "NoPre") +lineNoPreBN = WebVisual.Line(window, "NoPreBN") linePretrainSearch = WebVisual.Line(window, "PretrainSearch") linePretrainTrain = WebVisual.Line(window, "PretrainTrain") @@ -58,10 +59,16 @@ linePretrainTrain = WebVisual.Line(window, "PretrainTrain") -model = utils.SetDevice(Model.Net3Grad335()) -model = utils.LoadModel(model, CurrentPath+"/checkpointTrain.pkl") +# model = utils.SetDevice(Model.Net3Grad335()) +# model = utils.LoadModel(model, CurrentPath+"/checkpointTrain.pkl") +# optimizer = optim.SGD(model.parameters(), lr=0.1) +# Train.TrainEpochs(model,traindata,optimizer,testdata,3000,30,linePretrainTrain) + + +model = utils.SetDevice(Model.Net3335BN()) +# model = utils.LoadModel(model, CurrentPath+"/checkpointTrain.pkl") optimizer = optim.SGD(model.parameters(), lr=0.1) -Train.TrainEpochs(model,traindata,optimizer,testdata,3000,30,linePretrainTrain) +Train.TrainEpochs(model,traindata,optimizer,testdata,3000,30,lineNoPreBN) model = utils.SetDevice(Model.Net3335()) @@ -70,6 +77,7 @@ optimizer = optim.SGD(model.parameters(), lr=0.1) Train.TrainEpochs(model,traindata,optimizer,testdata,3000,30,lineNoPre) + # model = utils.SetDevice(Model.Net3Grad335()) # model = utils.LoadModel(model, CurrentPath+"/checkpointSearch.pkl") # optimizer = optim.SGD(model.parameters(), lr=0.1)