refine Search Weight calculater time.
This commit is contained in:
		
							parent
							
								
									76c9b7f05e
								
							
						
					
					
						commit
						130a6942b9
					
				|  | @ -28,7 +28,7 @@ from tools import utils, Train, Loader, WebVisual | ||||||
| import EvaluatorUnsuper | import EvaluatorUnsuper | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| batchsize = 128 | batchsize = 64 | ||||||
| 
 | 
 | ||||||
| # model = utils.SetDevice(Model.Net5Grad35()) | # model = utils.SetDevice(Model.Net5Grad35()) | ||||||
| # model = utils.SetDevice(Model.Net31535()) | # model = utils.SetDevice(Model.Net31535()) | ||||||
|  | @ -36,6 +36,9 @@ model = utils.SetDevice(Model.Net3Grad335()) | ||||||
| # model = utils.SetDevice(Model.Net3()) | # model = utils.SetDevice(Model.Net3()) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| layers = model.PrintLayer() | layers = model.PrintLayer() | ||||||
| layer = 0 | layer = 0 | ||||||
| # model = utils.LoadModel(model, CurrentPath+"/checkpoint.pkl") | # model = utils.LoadModel(model, CurrentPath+"/checkpoint.pkl") | ||||||
|  | @ -53,8 +56,27 @@ traindata, testdata = Loader.Cifar10Mono(batchsize, num_workers=0, shuffle=True) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| # weight = EvaluatorUnsuper.UnsuperLearnSearchWeight(model, layer, traindata, NumSearch=100000, SearchChannelRatio=8, Interation=10) | 
 | ||||||
|  | 
 | ||||||
|  | # a = [] | ||||||
|  | # for batch_idx, (data, target) in enumerate(traindata): | ||||||
|  | #     a = torch.jit.trace(model, data) | ||||||
|  | #     break | ||||||
|  | # print(a.graph) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | # weight = EvaluatorUnsuper.UnsuperLearnSearchWeight(model, layer, traindata, NumSearch=100000, SearchChannelRatio=32, Interation=10) | ||||||
| # np.save("WeightSearch.npy", weight) | # np.save("WeightSearch.npy", weight) | ||||||
|  | weight = np.load(CurrentPath+"WeightSearch.npy") | ||||||
|  | utils.NumpyToImage(weight, CurrentPath+"image",title="SearchWeight") | ||||||
| # weight = np.load(CurrentPath+"WeightSearch.npy") | # weight = np.load(CurrentPath+"WeightSearch.npy") | ||||||
| # bestweight,index = EvaluatorUnsuper.UnsuperLearnFindBestWeight(model,layer,weight,traindata,128,100000) | # bestweight,index = EvaluatorUnsuper.UnsuperLearnFindBestWeight(model,layer,weight,traindata,128,100000) | ||||||
| # np.save(CurrentPath+"bestweightSearch.npy", bestweight) | # np.save(CurrentPath+"bestweightSearch.npy", bestweight) | ||||||
|  | @ -68,11 +90,12 @@ traindata, testdata = Loader.Cifar10Mono(batchsize, num_workers=0, shuffle=True) | ||||||
| 
 | 
 | ||||||
| # weight = EvaluatorUnsuper.UnsuperLearnTrainWeight(model, layer, traindata, NumTrain=5000) | # weight = EvaluatorUnsuper.UnsuperLearnTrainWeight(model, layer, traindata, NumTrain=5000) | ||||||
| # np.save("WeightTrain.npy", weight) | # np.save("WeightTrain.npy", weight) | ||||||
| weight = np.load(CurrentPath+"WeightTrain.npy") | # utils.NumpyToImage(bestweight, CurrentPath+"image",title="TrainWeight") | ||||||
| bestweight, index = EvaluatorUnsuper.UnsuperLearnFindBestWeight(model, layer, weight, traindata, databatchs=128, interation=1000000) | # weight = np.load(CurrentPath+"WeightTrain.npy") | ||||||
| np.save(CurrentPath+"bestweightTrain.npy", bestweight) | # bestweight, index = EvaluatorUnsuper.UnsuperLearnFindBestWeight(model, layer, weight, traindata, databatchs=64, interation=1000000) | ||||||
| bestweight = np.load(CurrentPath+"bestweightTrain.npy") | # np.save(CurrentPath+"bestweightTrain.npy", bestweight) | ||||||
| utils.NumpyToImage(bestweight, CurrentPath+"image") | # bestweight = np.load(CurrentPath+"bestweightTrain.npy") | ||||||
|  | # utils.NumpyToImage(bestweight, CurrentPath+"image") | ||||||
| # EvaluatorUnsuper.SetModelConvWeight(model,layer,bestweight) | # EvaluatorUnsuper.SetModelConvWeight(model,layer,bestweight) | ||||||
| # utils.SaveModel(model,CurrentPath+"/checkpointTrain.pkl") | # utils.SaveModel(model,CurrentPath+"/checkpointTrain.pkl") | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -25,35 +25,15 @@ import Model as Model | ||||||
| from tools import utils, Train, Loader | from tools import utils, Train, Loader | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def GetScore(netmodel,layer,SearchLayer,DataSet,Interation=-1): | def GetScore(netmodel,layer,SearchLayer,DataSet): | ||||||
|     netmodel.eval() |     netmodel.eval() | ||||||
|     sample = utils.SetDevice(torch.empty((SearchLayer.out_channels,0))) |     sample = utils.SetDevice(torch.empty((SearchLayer.out_channels,0))) | ||||||
| 
 |  | ||||||
|     layer = layer-1 |     layer = layer-1 | ||||||
| 
 |     # for data in DataSet: | ||||||
|     # layerout = [] |     output = netmodel.ForwardLayer(DataSet,layer) | ||||||
|     #     layerint = [] |     output = SearchLayer(output) | ||||||
|     #     def getnet(self, input, output): |     output = torch.reshape(output.transpose(0,1),(SearchLayer.out_channels,-1)) | ||||||
|     #         layerout.append(output) |     sample = torch.cat((sample,output),1) | ||||||
|     #         layerint.append(input) |  | ||||||
|     # handle = netmodel.features[layer].register_forward_hook(getnet) |  | ||||||
|     # netmodel.ForwardLayer(data,layer) |  | ||||||
|     # output = layerout[0][:,:,:,:] |  | ||||||
|     # handle.remove() |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     for batch_idx, (data, target) in enumerate(DataSet): |  | ||||||
|         data = utils.SetDevice(data) |  | ||||||
|         target = utils.SetDevice(target) |  | ||||||
|         output = netmodel.ForwardLayer(data,layer) |  | ||||||
|         output = SearchLayer(output) |  | ||||||
|         data.detach() |  | ||||||
|         target.detach() |  | ||||||
|          |  | ||||||
|         output = torch.reshape(output.transpose(0,1),(SearchLayer.out_channels,-1)) |  | ||||||
|         sample = torch.cat((sample,output),1) |  | ||||||
|         if Interation > 0 and batch_idx >= (Interation-1): |  | ||||||
|             break |  | ||||||
|          |          | ||||||
|     sample_mean=torch.mean(sample,dim=1,keepdim=True) |     sample_mean=torch.mean(sample,dim=1,keepdim=True) | ||||||
|     dat1 = torch.mean(torch.abs(sample - sample_mean),dim=1,keepdim=True) |     dat1 = torch.mean(torch.abs(sample - sample_mean),dim=1,keepdim=True) | ||||||
|  | @ -61,36 +41,58 @@ def GetScore(netmodel,layer,SearchLayer,DataSet,Interation=-1): | ||||||
|     dat2 = torch.mean(dat2 * dat2,dim=1) |     dat2 = torch.mean(dat2 * dat2,dim=1) | ||||||
|     return dat2.cpu().detach().numpy() |     return dat2.cpu().detach().numpy() | ||||||
| 
 | 
 | ||||||
| def UnsuperLearnSearchWeight(model, layer, dataloader, NumSearch=10000, SaveChannelRatio=500, SearchChannelRatio=1, Interation=10): | def UnsuperLearnSearchWeight(model, layer, dataloader, NumSearch=10000, SaveChannel=4000, SearchChannelRatio=1, Interation=10): | ||||||
|  |     interationbar = tqdm(total=NumSearch)         | ||||||
|     tl = model.features[layer] |     tl = model.features[layer] | ||||||
|     newlayer = nn.Conv2d(tl.in_channels, tl.out_channels * SearchChannelRatio, tl.kernel_size, |     newlayer = nn.Conv2d(tl.in_channels, tl.out_channels * SearchChannelRatio, tl.kernel_size, | ||||||
|                          tl.stride, tl.padding, tl.dilation, tl.groups, tl.bias, tl.padding_mode) |                          tl.stride, tl.padding, tl.dilation, tl.groups, tl.bias, tl.padding_mode) | ||||||
|     newlayer = utils.SetDevice(newlayer) |     newlayer = utils.SetDevice(newlayer) | ||||||
| 
 | 
 | ||||||
|     newchannels = tl.out_channels * SaveChannelRatio |  | ||||||
|     newweightshape = list(newlayer.weight.data.shape) |     newweightshape = list(newlayer.weight.data.shape) | ||||||
|      |     newweightshape.insert(0,NumSearch) | ||||||
|     minactive = np.empty((0)) |     minactive = np.empty((0)) | ||||||
|     minweight = np.empty([0,newweightshape[1],newweightshape[2],newweightshape[3]]) |     minweight = np.empty([0,newweightshape[-3],newweightshape[-2],newweightshape[-1]]) | ||||||
|  |     newweight = np.random.uniform(-1.0,1.0,newweightshape).astype("float32") | ||||||
|  | 
 | ||||||
|  |     dataset = [] | ||||||
|  |     for batch_idx, (data, target) in enumerate(dataloader): | ||||||
|  |         data = utils.SetDevice(data) | ||||||
|  |         target = utils.SetDevice(target) | ||||||
|  |         dataset.append(data) | ||||||
|  |         if Interation > 0 and batch_idx >= (Interation-1): | ||||||
|  |             break | ||||||
|  |     dataset = torch.cat(dataset,0) | ||||||
|  |     model.eval() | ||||||
| 
 | 
 | ||||||
|     for i in range(NumSearch): |     for i in range(NumSearch): | ||||||
|         newweight = np.random.uniform(-1.0,1.0,newweightshape).astype("float32") |         newlayer.weight.data=utils.SetDevice(torch.from_numpy(newweight[i])) | ||||||
|         newlayer.weight.data=utils.SetDevice(torch.from_numpy(newweight)) |  | ||||||
| 
 | 
 | ||||||
|         score = GetScore(model, layer, newlayer, dataloader, Interation) |  | ||||||
| 
 | 
 | ||||||
|  |         output = model.ForwardLayer(dataset,layer-1) | ||||||
|  |         output = newlayer(output) | ||||||
|  |         output = torch.reshape(output.transpose(0,1),(newlayer.out_channels,-1)) | ||||||
|  | 
 | ||||||
|  |         sample_mean=torch.mean(output,dim=1,keepdim=True) | ||||||
|  |         dat1 = torch.mean(torch.abs(output - sample_mean),dim=1,keepdim=True) | ||||||
|  |         dat2 = (output - sample_mean)/dat1 | ||||||
|  |         dat2 = torch.mean(dat2 * dat2,dim=1) | ||||||
|  |         score = dat2.cpu().detach().numpy() | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  |         # score = GetScore(model, layer, newlayer, dataset) | ||||||
|         minactive = np.append(minactive, score) |         minactive = np.append(minactive, score) | ||||||
|         minweight = np.concatenate((minweight, newweight)) |         minweight = np.concatenate((minweight, newweight[i])) | ||||||
| 
 | 
 | ||||||
|         index = minactive.argsort() |         index = minactive.argsort() | ||||||
|         minactive = minactive[index[0:newchannels]] |         minactive = minactive[index[0:SaveChannel]] | ||||||
|         minweight = minweight[index[0:newchannels]] |         minweight = minweight[index[0:SaveChannel]] | ||||||
|         print("search random :" + str(i)) |  | ||||||
|         if i % (NumSearch/10) == 0: |         if i % (NumSearch/10) == 0: | ||||||
|             tl.data=utils.SetDevice(torch.from_numpy(minweight[0:tl.out_channels])) |             tl.data=utils.SetDevice(torch.from_numpy(minweight[0:tl.out_channels])) | ||||||
|             utils.SaveModel(model, CurrentPath+"/checkpoint.pkl") |             utils.SaveModel(model, CurrentPath+"/checkpoint.pkl") | ||||||
|  |         interationbar.update(1) | ||||||
|      |      | ||||||
|     tl.data=utils.SetDevice(torch.from_numpy(minweight[0:tl.out_channels])) |     tl.data=utils.SetDevice(torch.from_numpy(minweight[0:tl.out_channels])) | ||||||
|  |     interationbar.close() | ||||||
|     return minweight |     return minweight | ||||||
| 
 | 
 | ||||||
| def TrainLayer(netmodel, layer, SearchLayer, DataSet, Epoch=100): | def TrainLayer(netmodel, layer, SearchLayer, DataSet, Epoch=100): | ||||||
|  |  | ||||||
										
											Binary file not shown.
										
									
								
							
										
											Binary file not shown.
										
									
								
							
										
											Binary file not shown.
										
									
								
							| After Width: | Height: | Size: 12 KiB | 
										
											Binary file not shown.
										
									
								
							| After Width: | Height: | Size: 12 KiB | 
										
											Binary file not shown.
										
									
								
							| After Width: | Height: | Size: 12 KiB | 
										
											Binary file not shown.
										
									
								
							| After Width: | Height: | Size: 11 KiB | 
										
											Binary file not shown.
										
									
								
							
		Loading…
	
		Reference in New Issue