refine Search Weight calculater time.
This commit is contained in:
parent
76c9b7f05e
commit
130a6942b9
|
@ -28,7 +28,7 @@ from tools import utils, Train, Loader, WebVisual
|
|||
import EvaluatorUnsuper
|
||||
|
||||
|
||||
batchsize = 128
|
||||
batchsize = 64
|
||||
|
||||
# model = utils.SetDevice(Model.Net5Grad35())
|
||||
# model = utils.SetDevice(Model.Net31535())
|
||||
|
@ -36,6 +36,9 @@ model = utils.SetDevice(Model.Net3Grad335())
|
|||
# model = utils.SetDevice(Model.Net3())
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
layers = model.PrintLayer()
|
||||
layer = 0
|
||||
# model = utils.LoadModel(model, CurrentPath+"/checkpoint.pkl")
|
||||
|
@ -53,8 +56,27 @@ traindata, testdata = Loader.Cifar10Mono(batchsize, num_workers=0, shuffle=True)
|
|||
|
||||
|
||||
|
||||
# weight = EvaluatorUnsuper.UnsuperLearnSearchWeight(model, layer, traindata, NumSearch=100000, SearchChannelRatio=8, Interation=10)
|
||||
|
||||
|
||||
# a = []
|
||||
# for batch_idx, (data, target) in enumerate(traindata):
|
||||
# a = torch.jit.trace(model, data)
|
||||
# break
|
||||
# print(a.graph)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# weight = EvaluatorUnsuper.UnsuperLearnSearchWeight(model, layer, traindata, NumSearch=100000, SearchChannelRatio=32, Interation=10)
|
||||
# np.save("WeightSearch.npy", weight)
|
||||
weight = np.load(CurrentPath+"WeightSearch.npy")
|
||||
utils.NumpyToImage(weight, CurrentPath+"image",title="SearchWeight")
|
||||
# weight = np.load(CurrentPath+"WeightSearch.npy")
|
||||
# bestweight,index = EvaluatorUnsuper.UnsuperLearnFindBestWeight(model,layer,weight,traindata,128,100000)
|
||||
# np.save(CurrentPath+"bestweightSearch.npy", bestweight)
|
||||
|
@ -68,11 +90,12 @@ traindata, testdata = Loader.Cifar10Mono(batchsize, num_workers=0, shuffle=True)
|
|||
|
||||
# weight = EvaluatorUnsuper.UnsuperLearnTrainWeight(model, layer, traindata, NumTrain=5000)
|
||||
# np.save("WeightTrain.npy", weight)
|
||||
weight = np.load(CurrentPath+"WeightTrain.npy")
|
||||
bestweight, index = EvaluatorUnsuper.UnsuperLearnFindBestWeight(model, layer, weight, traindata, databatchs=128, interation=1000000)
|
||||
np.save(CurrentPath+"bestweightTrain.npy", bestweight)
|
||||
bestweight = np.load(CurrentPath+"bestweightTrain.npy")
|
||||
utils.NumpyToImage(bestweight, CurrentPath+"image")
|
||||
# utils.NumpyToImage(bestweight, CurrentPath+"image",title="TrainWeight")
|
||||
# weight = np.load(CurrentPath+"WeightTrain.npy")
|
||||
# bestweight, index = EvaluatorUnsuper.UnsuperLearnFindBestWeight(model, layer, weight, traindata, databatchs=64, interation=1000000)
|
||||
# np.save(CurrentPath+"bestweightTrain.npy", bestweight)
|
||||
# bestweight = np.load(CurrentPath+"bestweightTrain.npy")
|
||||
# utils.NumpyToImage(bestweight, CurrentPath+"image")
|
||||
# EvaluatorUnsuper.SetModelConvWeight(model,layer,bestweight)
|
||||
# utils.SaveModel(model,CurrentPath+"/checkpointTrain.pkl")
|
||||
|
||||
|
|
|
@ -25,72 +25,74 @@ import Model as Model
|
|||
from tools import utils, Train, Loader
|
||||
|
||||
|
||||
def GetScore(netmodel,layer,SearchLayer,DataSet,Interation=-1):
|
||||
def GetScore(netmodel,layer,SearchLayer,DataSet):
|
||||
netmodel.eval()
|
||||
sample = utils.SetDevice(torch.empty((SearchLayer.out_channels,0)))
|
||||
|
||||
layer = layer-1
|
||||
|
||||
# layerout = []
|
||||
# layerint = []
|
||||
# def getnet(self, input, output):
|
||||
# layerout.append(output)
|
||||
# layerint.append(input)
|
||||
# handle = netmodel.features[layer].register_forward_hook(getnet)
|
||||
# netmodel.ForwardLayer(data,layer)
|
||||
# output = layerout[0][:,:,:,:]
|
||||
# handle.remove()
|
||||
|
||||
|
||||
for batch_idx, (data, target) in enumerate(DataSet):
|
||||
data = utils.SetDevice(data)
|
||||
target = utils.SetDevice(target)
|
||||
output = netmodel.ForwardLayer(data,layer)
|
||||
output = SearchLayer(output)
|
||||
data.detach()
|
||||
target.detach()
|
||||
# for data in DataSet:
|
||||
output = netmodel.ForwardLayer(DataSet,layer)
|
||||
output = SearchLayer(output)
|
||||
output = torch.reshape(output.transpose(0,1),(SearchLayer.out_channels,-1))
|
||||
sample = torch.cat((sample,output),1)
|
||||
|
||||
output = torch.reshape(output.transpose(0,1),(SearchLayer.out_channels,-1))
|
||||
sample = torch.cat((sample,output),1)
|
||||
if Interation > 0 and batch_idx >= (Interation-1):
|
||||
break
|
||||
|
||||
sample_mean=torch.mean(sample,dim=1,keepdim=True)
|
||||
dat1 = torch.mean(torch.abs(sample - sample_mean),dim=1,keepdim=True)
|
||||
dat2 = (sample - sample_mean)/dat1
|
||||
dat2 = torch.mean(dat2 * dat2,dim=1)
|
||||
return dat2.cpu().detach().numpy()
|
||||
|
||||
def UnsuperLearnSearchWeight(model, layer, dataloader, NumSearch=10000, SaveChannelRatio=500, SearchChannelRatio=1, Interation=10):
|
||||
def UnsuperLearnSearchWeight(model, layer, dataloader, NumSearch=10000, SaveChannel=4000, SearchChannelRatio=1, Interation=10):
|
||||
interationbar = tqdm(total=NumSearch)
|
||||
tl = model.features[layer]
|
||||
newlayer = nn.Conv2d(tl.in_channels, tl.out_channels * SearchChannelRatio, tl.kernel_size,
|
||||
tl.stride, tl.padding, tl.dilation, tl.groups, tl.bias, tl.padding_mode)
|
||||
newlayer = utils.SetDevice(newlayer)
|
||||
|
||||
newchannels = tl.out_channels * SaveChannelRatio
|
||||
newweightshape = list(newlayer.weight.data.shape)
|
||||
|
||||
newweightshape.insert(0,NumSearch)
|
||||
minactive = np.empty((0))
|
||||
minweight = np.empty([0,newweightshape[1],newweightshape[2],newweightshape[3]])
|
||||
minweight = np.empty([0,newweightshape[-3],newweightshape[-2],newweightshape[-1]])
|
||||
newweight = np.random.uniform(-1.0,1.0,newweightshape).astype("float32")
|
||||
|
||||
dataset = []
|
||||
for batch_idx, (data, target) in enumerate(dataloader):
|
||||
data = utils.SetDevice(data)
|
||||
target = utils.SetDevice(target)
|
||||
dataset.append(data)
|
||||
if Interation > 0 and batch_idx >= (Interation-1):
|
||||
break
|
||||
dataset = torch.cat(dataset,0)
|
||||
model.eval()
|
||||
|
||||
for i in range(NumSearch):
|
||||
newweight = np.random.uniform(-1.0,1.0,newweightshape).astype("float32")
|
||||
newlayer.weight.data=utils.SetDevice(torch.from_numpy(newweight))
|
||||
newlayer.weight.data=utils.SetDevice(torch.from_numpy(newweight[i]))
|
||||
|
||||
score = GetScore(model, layer, newlayer, dataloader, Interation)
|
||||
|
||||
output = model.ForwardLayer(dataset,layer-1)
|
||||
output = newlayer(output)
|
||||
output = torch.reshape(output.transpose(0,1),(newlayer.out_channels,-1))
|
||||
|
||||
sample_mean=torch.mean(output,dim=1,keepdim=True)
|
||||
dat1 = torch.mean(torch.abs(output - sample_mean),dim=1,keepdim=True)
|
||||
dat2 = (output - sample_mean)/dat1
|
||||
dat2 = torch.mean(dat2 * dat2,dim=1)
|
||||
score = dat2.cpu().detach().numpy()
|
||||
|
||||
|
||||
# score = GetScore(model, layer, newlayer, dataset)
|
||||
minactive = np.append(minactive, score)
|
||||
minweight = np.concatenate((minweight, newweight))
|
||||
minweight = np.concatenate((minweight, newweight[i]))
|
||||
|
||||
index = minactive.argsort()
|
||||
minactive = minactive[index[0:newchannels]]
|
||||
minweight = minweight[index[0:newchannels]]
|
||||
print("search random :" + str(i))
|
||||
minactive = minactive[index[0:SaveChannel]]
|
||||
minweight = minweight[index[0:SaveChannel]]
|
||||
if i % (NumSearch/10) == 0:
|
||||
tl.data=utils.SetDevice(torch.from_numpy(minweight[0:tl.out_channels]))
|
||||
utils.SaveModel(model, CurrentPath+"/checkpoint.pkl")
|
||||
interationbar.update(1)
|
||||
|
||||
tl.data=utils.SetDevice(torch.from_numpy(minweight[0:tl.out_channels]))
|
||||
interationbar.close()
|
||||
return minweight
|
||||
|
||||
def TrainLayer(netmodel, layer, SearchLayer, DataSet, Epoch=100):
|
||||
|
@ -227,4 +229,4 @@ def CrossDistanceOfKernel(k1, k2, BatchSize1=1, BatchSize2=1):
|
|||
leftshape[0] = -1
|
||||
left = left.reshape(leftshape)
|
||||
|
||||
return DistanceOfKernel(right, left, left.shape[0])
|
||||
return DistanceOfKernel(right, left, left.shape[0])
|
||||
|
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
After Width: | Height: | Size: 12 KiB |
Binary file not shown.
After Width: | Height: | Size: 12 KiB |
Binary file not shown.
After Width: | Height: | Size: 12 KiB |
Binary file not shown.
After Width: | Height: | Size: 11 KiB |
Binary file not shown.
Loading…
Reference in New Issue