From c2e149fb680f1769644c517c6086243bafc2b741 Mon Sep 17 00:00:00 2001 From: colin Date: Thu, 31 Oct 2019 14:35:47 +0800 Subject: [PATCH] Check if train kernel is benifit to pretrain. --- FilterEvaluator/TrainNetwork.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/FilterEvaluator/TrainNetwork.py b/FilterEvaluator/TrainNetwork.py index 97ea0cd..25704a0 100644 --- a/FilterEvaluator/TrainNetwork.py +++ b/FilterEvaluator/TrainNetwork.py @@ -44,7 +44,8 @@ batchsize = 128 # traindata, testdata = Loader.RandomMnist(batchsize, num_workers=4, style="Horizontal") # traindata, testdata = Loader.RandomMnist(batchsize, num_workers=4, style="VerticalOneLine") # traindata, testdata = Loader.RandomMnist(batchsize, num_workers=4, style="VerticalZebra") -traindata, testdata = Loader.Cifar10Mono(batchsize, num_workers=4,shuffle=True,trainsize=500) +# traindata, testdata = Loader.Cifar10Mono(batchsize, num_workers=4,shuffle=True,trainsize=500) +traindata, testdata = Loader.Cifar10Mono(batchsize, num_workers=4,shuffle=True,trainsize=0) WebVisual.InitVisdom() @@ -57,10 +58,10 @@ linePretrainTrain = WebVisual.Line(window, "PretrainTrain") -# model = utils.SetDevice(Model.Net3Grad335()) -# model = utils.LoadModel(model, CurrentPath+"/checkpointTrain.pkl") -# optimizer = optim.SGD(model.parameters(), lr=0.1) -# Train.TrainEpochs(model,traindata,optimizer,testdata,3000,30,linePretrainTrain) +model = utils.SetDevice(Model.Net3Grad335()) +model = utils.LoadModel(model, CurrentPath+"/checkpointTrain.pkl") +optimizer = optim.SGD(model.parameters(), lr=0.1) +Train.TrainEpochs(model,traindata,optimizer,testdata,3000,30,linePretrainTrain) model = utils.SetDevice(Model.Net3335()) @@ -69,10 +70,12 @@ optimizer = optim.SGD(model.parameters(), lr=0.1) Train.TrainEpochs(model,traindata,optimizer,testdata,3000,30,lineNoPre) -model = utils.SetDevice(Model.Net3Grad335()) -model = utils.LoadModel(model, CurrentPath+"/checkpointSearch.pkl") -optimizer = optim.SGD(model.parameters(), lr=0.1) -Train.TrainEpochs(model,traindata,optimizer,testdata,3000,30,linePretrainSearch) +# model = utils.SetDevice(Model.Net3Grad335()) +# model = utils.LoadModel(model, CurrentPath+"/checkpointSearch.pkl") +# optimizer = optim.SGD(model.parameters(), lr=0.1) +# Train.TrainEpochs(model,traindata,optimizer,testdata,3000,30,linePretrainSearch) + +