Add bn layer to 3335,Get if it is benify that.

This commit is contained in:
Colin 2019-11-06 15:17:14 +08:00
parent c2e149fb68
commit 8bb2c5562d
3 changed files with 29 additions and 3 deletions

1
.gitignore vendored
View File

@ -8,3 +8,4 @@ Dataset/
/*/__pycache__ /*/__pycache__
.mypy_cache .mypy_cache
/*/image* /*/image*
/*/visdom.server.log

View File

@ -60,6 +60,23 @@ class Net3335(UniModule.ModuleBase):
x = x.view(-1, 1*10) x = x.view(-1, 1*10)
return F.log_softmax(x, dim=1) return F.log_softmax(x, dim=1)
class Net3335BN(UniModule.ModuleBase):
def __init__(self):
super(Net3335BN, self).__init__()
layers = []
layers += [nn.Conv2d(1, 8, kernel_size=3,bias=False,padding=1),nn.MaxPool2d(kernel_size=2, stride=2),nn.Sigmoid()]
layers += [nn.BatchNorm2d(8)]
layers += [nn.Conv2d(8, 8, kernel_size=3,bias=False),nn.MaxPool2d(kernel_size=2, stride=2),nn.Sigmoid()]
layers += [nn.BatchNorm2d(8)]
layers += [nn.Conv2d(8, 8, kernel_size=3,bias=False),nn.Sigmoid()]
layers += [nn.BatchNorm2d(8)]
layers += [nn.Conv2d(8, 10, kernel_size=5,bias=False)]
self.features = nn.Sequential(*layers)
def forward(self, x):
x = self.features(x)
x = x.view(-1, 1*10)
return F.log_softmax(x, dim=1)
class Net3Grad335(UniModule.ModuleBase): class Net3Grad335(UniModule.ModuleBase):
def __init__(self): def __init__(self):
super(Net3Grad335, self).__init__() super(Net3Grad335, self).__init__()

View File

@ -51,6 +51,7 @@ traindata, testdata = Loader.Cifar10Mono(batchsize, num_workers=4,shuffle=True,t
WebVisual.InitVisdom() WebVisual.InitVisdom()
window = WebVisual.LineWin() window = WebVisual.LineWin()
lineNoPre = WebVisual.Line(window, "NoPre") lineNoPre = WebVisual.Line(window, "NoPre")
lineNoPreBN = WebVisual.Line(window, "NoPreBN")
linePretrainSearch = WebVisual.Line(window, "PretrainSearch") linePretrainSearch = WebVisual.Line(window, "PretrainSearch")
linePretrainTrain = WebVisual.Line(window, "PretrainTrain") linePretrainTrain = WebVisual.Line(window, "PretrainTrain")
@ -58,10 +59,16 @@ linePretrainTrain = WebVisual.Line(window, "PretrainTrain")
model = utils.SetDevice(Model.Net3Grad335()) # model = utils.SetDevice(Model.Net3Grad335())
model = utils.LoadModel(model, CurrentPath+"/checkpointTrain.pkl") # model = utils.LoadModel(model, CurrentPath+"/checkpointTrain.pkl")
# optimizer = optim.SGD(model.parameters(), lr=0.1)
# Train.TrainEpochs(model,traindata,optimizer,testdata,3000,30,linePretrainTrain)
model = utils.SetDevice(Model.Net3335BN())
# model = utils.LoadModel(model, CurrentPath+"/checkpointTrain.pkl")
optimizer = optim.SGD(model.parameters(), lr=0.1) optimizer = optim.SGD(model.parameters(), lr=0.1)
Train.TrainEpochs(model,traindata,optimizer,testdata,3000,30,linePretrainTrain) Train.TrainEpochs(model,traindata,optimizer,testdata,3000,30,lineNoPreBN)
model = utils.SetDevice(Model.Net3335()) model = utils.SetDevice(Model.Net3335())
@ -70,6 +77,7 @@ optimizer = optim.SGD(model.parameters(), lr=0.1)
Train.TrainEpochs(model,traindata,optimizer,testdata,3000,30,lineNoPre) Train.TrainEpochs(model,traindata,optimizer,testdata,3000,30,lineNoPre)
# model = utils.SetDevice(Model.Net3Grad335()) # model = utils.SetDevice(Model.Net3Grad335())
# model = utils.LoadModel(model, CurrentPath+"/checkpointSearch.pkl") # model = utils.LoadModel(model, CurrentPath+"/checkpointSearch.pkl")
# optimizer = optim.SGD(model.parameters(), lr=0.1) # optimizer = optim.SGD(model.parameters(), lr=0.1)