Speed up by multi cpu to handle histogram

This commit is contained in:
c 2019-12-18 18:22:59 +08:00
parent 418e79df60
commit 8404d93880
1 changed files with 63 additions and 14 deletions

View File

@ -179,12 +179,15 @@ def UnsuperLearnFindBestWeight(netmodel, layer, weight, dataloader, databatchs=1
interationbar.close()
return bestweight,indexs[sortindex]
# search best weight from random data
def UnsuperLearnSearchBestWeight(netmodel, layer, dataloader, databatchs=128, stepsize=1000, interation=1000):
interationbar = tqdm(total=interation)
forwardlayer = layer -1
bestweight = []
bestentropy = 100.0
netmodel.eval()
tl = netmodel.features[layer]
outchannels = tl.out_channels
@ -208,6 +211,40 @@ def UnsuperLearnSearchBestWeight(netmodel, layer, dataloader, databatchs=128, st
for i in range(outchannels):
shift.append(1<<i)
shift = utils.SetDevice(torch.from_numpy(np.array(shift).astype("uint8")))
# shift = torch.from_numpy(np.array(shift).astype("uint8"))
bestweight = []
bestentropy = [100.0]
bittedset = []
bittedLock = threading.Lock()
bestLock = threading.Lock()
class CPU (threading.Thread):
def __init__(self):
threading.Thread.__init__(self)
def run(self):
hasdata = 0
bittedLock.acquire()
if len(bittedset)>0:
bitted = bittedset.pop()
hasdata = 1
bittedLock.release()
if hasdata > 0:
entropys = []
for i in range(len(indexs)):
histced = bitted[:,i].histc(256,0,255).type(torch.float32)
histced = histced[histced>0]
histced = histced/histced.sum()
entropy = (histced.log2()*histced).sum()
entropys.append(entropy.numpy())
argmin = np.argmin(entropys)
bestLock.acquire()
if entropys[argmin] < bestentropy[0]:
bestweight = newweight[indexs[argmin]]
bestentropy[0] = entropys[argmin]
print("finded better entropy")
bestLock.release()
for j in range(interation):
newweightshape = list(newlayer.weight.data.shape)
@ -228,21 +265,33 @@ def UnsuperLearnSearchBestWeight(netmodel, layer, dataloader, databatchs=128, st
# 1000 8
meaned = reshaped.mean(0)
# 102400 1000
bitted = ((reshaped > meaned)* shift).sum(2).type(torch.float32)
bitted = ((reshaped > meaned)* shift).sum(2).type(torch.float32).detach().cpu()
entropys = []
for i in range(len(indexs)):
histced = bitted[:,i].histc(256,0,255).type(torch.float32)
histced = histced[histced>0]
histced = histced/histced.sum()
entropy = (histced.log2()*histced).sum()
entropys.append(entropy.detach().cpu().numpy())
bittedLock.acquire()
bittedset.append(bitted)
bittedLock.release()
threadcpu = CPU()
threadcpu.start()
# entropys = []
# for i in range(len(indexs)):
# histced = bitted[:,i].histc(256,0,255).type(torch.float32)
# histced = histced[histced>0]
# histced = histced/histced.sum()
# entropy = (histced.log2()*histced).sum()
# entropys.append(entropy.detach().cpu().numpy())
# argmin = np.argmin(entropys)
# if entropys[argmin] < bestentropy:
# bestweight = newweight[indexs[argmin]]
# bestentropy = entropys[argmin]
argmin = np.argmin(entropys)
if entropys[argmin] < bestentropy:
bestweight = newweight[indexs[argmin]]
bestentropy = entropys[argmin]
interationbar.update(1)
interationbar.set_description("left:"+str(len(bittedset)))
while bittedset:
time.sleep(100)
interationbar.close()
return bestweight