This commit is contained in:
colin 2019-10-15 15:26:03 +08:00
parent febad1f9c3
commit 8a9acaa94e
7 changed files with 20 additions and 50 deletions

View File

@ -28,7 +28,7 @@ from tools import utils, Train, Loader, WebVisual
import EvaluatorUnsuper import EvaluatorUnsuper
batchsize = 64 batchsize = 128
# model = utils.SetDevice(Model.Net5Grad35()) # model = utils.SetDevice(Model.Net5Grad35())
# model = utils.SetDevice(Model.Net31535()) # model = utils.SetDevice(Model.Net31535())
@ -57,48 +57,18 @@ traindata, testdata = Loader.Cifar10Mono(batchsize, num_workers=0, shuffle=False
# weight,active = EvaluatorUnsuper.UnsuperLearnSearchWeight(model, layer, traindata, NumSearch=1000000, SearchChannelRatio=32, Interation=5)
# a = [] # np.save("WeightSearch.npy", weight)
# for batch_idx, (data, target) in enumerate(traindata):
# a = torch.jit.trace(model, data)
# break
# print(a.graph)
# for batch_idx, (data, target) in enumerate(traindata):
# utils.NumpyToImage(data.cpu().detach().numpy(), CurrentPath+"image", title="TrainData")
# break
# weight, active = EvaluatorUnsuper.UnsuperLearnSearchWeight(
# model, layer, traindata, NumSearch=10, SaveChannel=4000, SearchChannelRatio=32, Interation=10)
# utils.NumpyToImage(weight, CurrentPath+"image",title="SearchWeight")
# b =0
weight,active = EvaluatorUnsuper.UnsuperLearnSearchWeight(model, layer, traindata, NumSearch=100000, SearchChannelRatio=32, Interation=10)
np.save("WeightSearch.npy", weight)
weight = np.load(CurrentPath+"WeightSearch.npy")
utils.NumpyToImage(weight, CurrentPath+"image",title="SearchWeight")
# weight = np.load(CurrentPath+"WeightSearch.npy") # weight = np.load(CurrentPath+"WeightSearch.npy")
# bestweight,index = EvaluatorUnsuper.UnsuperLearnFindBestWeight(model,layer,weight,traindata,128,100000) # utils.NumpyToImage(weight, CurrentPath+"image",title="SearchWeight")
# np.save(CurrentPath+"bestweightSearch.npy", bestweight) weight = np.load(CurrentPath+"WeightSearch.npy")
# bestweight = np.load(CurrentPath+"bestweightSearch.npy") weight = weight[0:256]
# utils.NumpyToImage(bestweight, CurrentPath+"image",title="SearchWerightBest") bestweight,index = EvaluatorUnsuper.UnsuperLearnFindBestWeight(model,layer,weight,traindata,32,400000)
# EvaluatorUnsuper.SetModelConvWeight(model,layer,bestweight) np.save(CurrentPath+"bestweightSearch.npy", bestweight)
# utils.SaveModel(model,CurrentPath+"/checkpointSearch.pkl") bestweight = np.load(CurrentPath+"bestweightSearch.npy")
utils.NumpyToImage(bestweight, CurrentPath+"image",title="SearchWerightBest")
EvaluatorUnsuper.SetModelConvWeight(model,layer,bestweight)
utils.SaveModel(model,CurrentPath+"/checkpointSearch.pkl")

View File

@ -57,10 +57,10 @@ linePretrainTrain = WebVisual.Line(window, "PretrainTrain")
model = utils.SetDevice(Model.Net3Grad335()) # model = utils.SetDevice(Model.Net3Grad335())
model = utils.LoadModel(model, CurrentPath+"/checkpointTrain.pkl") # model = utils.LoadModel(model, CurrentPath+"/checkpointTrain.pkl")
optimizer = optim.SGD(model.parameters(), lr=0.1) # optimizer = optim.SGD(model.parameters(), lr=0.1)
Train.TrainEpochs(model,traindata,optimizer,testdata,3000,30,linePretrainTrain) # Train.TrainEpochs(model,traindata,optimizer,testdata,3000,30,linePretrainTrain)
model = utils.SetDevice(Model.Net3335()) model = utils.SetDevice(Model.Net3335())
@ -69,10 +69,10 @@ optimizer = optim.SGD(model.parameters(), lr=0.1)
Train.TrainEpochs(model,traindata,optimizer,testdata,3000,30,lineNoPre) Train.TrainEpochs(model,traindata,optimizer,testdata,3000,30,lineNoPre)
# model = utils.SetDevice(Model.Net3Grad335()) model = utils.SetDevice(Model.Net3Grad335())
# model = utils.LoadModel(model, CurrentPath+"/checkpointSearch.pkl") model = utils.LoadModel(model, CurrentPath+"/checkpointSearch.pkl")
# optimizer = optim.SGD(model.parameters(), lr=0.1) optimizer = optim.SGD(model.parameters(), lr=0.1)
# Train.TrainEpochs(model,traindata,optimizer,testdata,3000,30,linePretrainSearch) Train.TrainEpochs(model,traindata,optimizer,testdata,3000,30,linePretrainSearch)

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.