194 lines
6.3 KiB
Python
Executable File
194 lines
6.3 KiB
Python
Executable File
import sys
|
|
import math
|
|
import torch
|
|
import shutil
|
|
import time
|
|
import os
|
|
import random
|
|
from easydict import EasyDict as edict
|
|
import yaml
|
|
import numpy as np
|
|
import argparse
|
|
import cv2
|
|
|
|
class AverageMeter(object):
    """Compute and store the average and current value of a running metric.

    ``update(val, n)`` treats ``val`` as standing for ``n`` samples, so ``sum``
    accumulates ``val * n`` and ``avg`` is the sample-weighted mean.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all accumulated statistics."""
        self.val = 0.   # most recent value passed to update()
        self.avg = 0.   # weighted running mean: sum / count
        self.sum = 0.   # running sum of val * n
        self.count = 0  # total number of samples seen

    def update(self, val, n=1):
        """Record ``val`` as the value of ``n`` samples and refresh the mean.

        Args:
            val: Latest observed value; weighted by ``n`` in the mean.
            n: Number of samples ``val`` represents. ``n=0`` updates ``val``
                without contributing to the mean.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # Avoid ZeroDivisionError when no samples have been counted yet
        # (e.g. update(x, n=0) right after reset()).
        self.avg = self.sum / self.count if self.count else 0.
|
|
|
|
def AdjustLearningRate(optimizermodule, iters, base_lr, policy_parameter, policy='step', multiple=None):
    """Compute the learning rate for ``iters`` and write it into the optimizer.

    Most policy names and formulas mirror Caffe's ``lr_policy`` options:

    * ``'fixed'``:     ``base_lr``
    * ``'step'``:      ``base_lr * gamma ** (iters // step_size)``
    * ``'exp'``:       ``base_lr * gamma ** iters``
    * ``'inv'``:       ``base_lr * (1 + gamma * iters) ** -power``
    * ``'multistep'``: multiply by ``gamma`` once for every
      ``stepvalue <= iters`` (``stepvalue`` assumed ascending)
    * ``'poly'``:      ``base_lr * max(0, 1 - iters / max_iter) ** power``
    * ``'sigmoid'``:   ``base_lr / (1 + exp(-gamma * (iters - stepsize)))``
    * ``'multistep-poly'``: ``'multistep'`` decay, then a polynomial decay
      across the current step interval, floored at one further ``gamma``
      factor.

    Args:
        optimizermodule: Object with a ``param_groups`` list of dicts
            (e.g. a ``torch.optim.Optimizer``); each group's ``'lr'`` is set.
        iters: Current iteration number.
        base_lr: Base learning rate.
        policy_parameter: Mapping with the keys the chosen policy reads
            (``gamma``, ``step_size``, ``power``, ``stepvalue``,
            ``max_iter``, ``stepsize``).
        policy: Name of the schedule (see above).
        multiple: Per-param-group multipliers indexed like
            ``optimizermodule.param_groups``. ``None`` (default) applies a
            multiplier of 1 to every group.

    Returns:
        The un-multiplied learning rate for this iteration.

    Raises:
        ValueError: If ``policy`` is not one of the names above.
    """
    if policy == 'fixed':
        lr = base_lr
    elif policy == 'step':
        lr = base_lr * (policy_parameter['gamma'] ** (iters // policy_parameter['step_size']))
    elif policy == 'exp':
        lr = base_lr * (policy_parameter['gamma'] ** iters)
    elif policy == 'inv':
        lr = base_lr * ((1 + policy_parameter['gamma'] * iters) ** (-policy_parameter['power']))
    elif policy == 'multistep':
        lr = base_lr
        for stepvalue in policy_parameter['stepvalue']:
            if iters < stepvalue:
                break
            lr *= policy_parameter['gamma']
    elif policy == 'poly':
        # Clamp at 0 so iters > max_iter gives lr == 0 instead of a negative
        # base raised to a fractional power (a complex number in Python 3).
        progress = max(0.0, 1 - iters * 1.0 / policy_parameter['max_iter'])
        lr = base_lr * (progress ** policy_parameter['power'])
    elif policy == 'sigmoid':
        lr = base_lr * (1.0 / (1 + math.exp(-policy_parameter['gamma'] * (iters - policy_parameter['stepsize']))))
    elif policy == 'multistep-poly':
        lr = base_lr
        stepstart = 0
        stepend = policy_parameter['max_iter']
        for stepvalue in policy_parameter['stepvalue']:
            if iters >= stepvalue:
                lr *= policy_parameter['gamma']
                stepstart = stepvalue
            else:
                stepend = stepvalue
                break
        span = stepend - stepstart
        # Same clamp as 'poly'. An empty interval (last step at max_iter)
        # counts as fully decayed instead of dividing by zero.
        progress = max(0.0, 1 - (iters - stepstart) * 1.0 / span) if span else 0.0
        lr = max(lr * policy_parameter['gamma'], lr * progress ** policy_parameter['power'])
    else:
        raise ValueError('Unknown learning-rate policy: {!r}'.format(policy))

    for i, param_group in enumerate(optimizermodule.param_groups):
        param_group['lr'] = lr * (1 if multiple is None else multiple[i])
    return lr
|
|
|
|
def ConstructModel(args):
    """Build the model named by ``args.modelname``.

    Imports the module ``args.modelname`` (``'../network/'`` is added to
    ``sys.path`` so it can be found there) and delegates to that module's
    ``ConstructModel(args)``.

    Args:
        args: Config object with at least a ``modelname`` attribute; it is
            passed through unchanged to the model module.

    Returns:
        Whatever the model module's ``ConstructModel`` returns.
    """
    import importlib

    network_dir = '../network/'  # relative to the current working directory
    # Append only once so repeated calls do not keep growing sys.path.
    if network_dir not in sys.path:
        sys.path.append(network_dir)
    # import_module returns the named (sub)module itself, so dotted names
    # such as "pkg.model" also resolve to the right module.
    module = importlib.import_module(args.modelname)
    return module.ConstructModel(args)
|
|
|
|
|
|
def GetCheckpoint(filename):
    """Load a checkpoint's ``'state_dict'`` and strip ``'module.'`` prefixes.

    Tensors are mapped to the CPU. Any key beginning with ``'module.'`` (the
    prefix ``torch.nn.DataParallel`` adds to parameter names) has that
    prefix removed so the result loads into an unwrapped model.

    Note: ``torch.load`` unpickles the file; only load trusted checkpoints.

    Args:
        filename: Path to a ``torch.save`` file holding a dict with a
            ``'state_dict'`` entry.

    Returns:
        An ``OrderedDict`` of cleaned parameter names to values, in the
        original key order.
    """
    from collections import OrderedDict

    prefix = "module."
    checkpoint = torch.load(filename, map_location='cpu')
    raw_state = checkpoint['state_dict']
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in raw_state.items()
    )
|
|
def SaveCheckpoint(model, filename=None, is_best=False):
    """Save ``model``'s parameters as ``{'state_dict': ...}``.

    This layout is what ``GetCheckpoint`` reads back (``['state_dict']``).

    Args:
        model: Module whose ``state_dict()`` is saved.
        filename: Destination path. Defaults to
            ``'checkpoint_<unix time>.pth.tar'``, generated at call time.
        is_best: If true, also copy the file to ``best_<basename>`` in the
            same directory.
    """
    if filename is None:
        # Built per call: a default-argument expression is evaluated only
        # once at import, so every call would reuse the same timestamp.
        filename = 'checkpoint_' + str(time.time()) + '.pth.tar'
    # state_dict must be *called*; saving the bound method would pickle the
    # whole model object instead of its parameters.
    state = {'state_dict': model.state_dict()}
    torch.save(state, filename)
    if is_best:
        # Prefix only the file name so paths containing directories work.
        directory, basename = os.path.split(filename)
        shutil.copyfile(filename, os.path.join(directory, 'best_' + basename))
|
|
def LoadCheckpoint(model, pretrained):
    """Load weights from ``pretrained`` into ``model`` if that file exists.

    Args:
        model: Module to load the weights into (modified in place).
        pretrained: Checkpoint path readable by ``GetCheckpoint``. ``None``,
            ``''`` and the literal string ``'None'`` (as a config file may
            spell it) all mean "no checkpoint".

    Returns:
        ``model`` after loading, or ``""`` when nothing was loaded (no path
        given or the file does not exist).
    """
    # Test for None/'' before os.path.exists: os.path.exists(None) raises
    # TypeError instead of returning False.
    if not pretrained or pretrained == 'None' or not os.path.exists(pretrained):
        return ""
    model.load_state_dict(GetCheckpoint(pretrained))
    return model
|
|
|
|
|
|
def SaveModel(model, filename=None):
    """Save ``model``'s parameters as ``{'model_state_dict': ...}``.

    This layout is what ``LoadModel`` reads back.

    Args:
        model: Module whose ``state_dict()`` is saved.
        filename: Destination path. Defaults to
            ``'checkpoint_<unix time>.pkl'``, generated at call time.
    """
    if filename is None:
        # Built per call: a default-argument expression is evaluated only
        # once at import, so every call without a filename would overwrite
        # the same file.
        filename = 'checkpoint_' + str(time.time()) + '.pkl'
    torch.save({'model_state_dict': model.state_dict()}, filename)
|
|
def LoadModel(model, filename):
    """Restore ``model`` from a file holding a ``'model_state_dict'`` entry.

    The file is loaded onto the CPU, and its ``'model_state_dict'`` value is
    passed to ``model.load_state_dict``. ``torch.load`` unpickles the file,
    so only load files from trusted sources.

    Args:
        model: Module to load the weights into (modified in place).
        filename: Path to a ``torch.save`` file containing a dict with a
            ``'model_state_dict'`` entry.

    Returns:
        The same ``model`` object.
    """
    weights = torch.load(filename, map_location="cpu")['model_state_dict']
    model.load_state_dict(weights)
    return model
|
|
|
|
|
|
def SetDevice(Obj, deviceid=None):
    """Move ``Obj`` to the GPU, wrapping it in ``DataParallel`` for multi-GPU.

    Args:
        Obj: Module (or other object with ``.cuda()``) to place.
        deviceid: Sequence of GPU ids; only its length is used, because the
            ids handed to ``DataParallel`` are ``0 .. len(deviceid) - 1``
            (presumably the chosen GPUs were already exposed through
            ``CUDA_VISIBLE_DEVICES``, which renumbers them from 0; confirm
            with callers). ``None`` (default) means every GPU visible at
            call time.

    Returns:
        ``Obj`` unchanged when CUDA is unavailable, ``Obj.cuda()`` for a
        single device, otherwise ``DataParallel(Obj).cuda()``.
    """
    if not torch.cuda.is_available():
        return Obj
    if deviceid is None:
        # Query at call time; a default-argument expression would capture
        # the device count at import, before CUDA_VISIBLE_DEVICES may change.
        deviceid = range(torch.cuda.device_count())
    if len(deviceid) > 1:
        gpu = list(range(len(deviceid)))
        return torch.nn.DataParallel(Obj, device_ids=gpu).cuda()
    return Obj.cuda()
|
|
|
|
def ConstructDataset(args):
    """Build the train/validation loaders named by ``args.dataset``.

    Imports the module ``args.dataset`` (``'../dataset/'`` is added to
    ``sys.path`` so it can be found there) and delegates to that module's
    ``ConstructDataset(args)``, which must return exactly two items.

    Args:
        args: Config object with at least a ``dataset`` attribute; it is
            passed through unchanged to the dataset module.

    Returns:
        Tuple ``(train_loader, val_loader)``.
    """
    import importlib

    dataset_dir = '../dataset/'  # relative to the current working directory
    # Append only once so repeated calls do not keep growing sys.path.
    if dataset_dir not in sys.path:
        sys.path.append(dataset_dir)
    # import_module returns the named (sub)module itself, so dotted names
    # also resolve to the right module.
    module = importlib.import_module(args.dataset)
    train_loader, val_loader = module.ConstructDataset(args)
    return train_loader, val_loader
|
|
|
|
|
|
|
|
|
|
def parse():
    """Parse the command line.

    Returns:
        ``argparse.Namespace`` with one attribute, ``config``: the path of
        the config file (default ``'config.yml'``).
    """
    cli = argparse.ArgumentParser()
    cli.add_argument(
        '--config',
        dest='config',
        type=str,
        default='config.yml',
        help='to set the parameters',
    )
    options = cli.parse_args()
    return options
|
|
|
|
def Config(filename):
    """Load a YAML file into an attribute-accessible ``EasyDict``.

    Every top-level entry is printed as ``key: value``.

    Args:
        filename: Path to the YAML config file.

    Returns:
        ``EasyDict`` of the parsed config.
    """
    with open(filename, 'r') as f:
        # safe_load: yaml.load without an explicit Loader is rejected by
        # PyYAML >= 6, and in older releases it could construct arbitrary
        # Python objects from the file's contents.
        parser = edict(yaml.safe_load(f))
    for key in parser:
        print('{}: {}'.format(key, parser[key]))
    return parser
|
|
|
|
def SetCUDAVISIBLEDEVICES(deviceid):
    """Set ``CUDA_VISIBLE_DEVICES`` to the given GPU ids and print them.

    NOTE: CUDA reads this variable when it initialises, so presumably this
    must be called before any CUDA work in the process; confirm with callers.

    Args:
        deviceid: GPU ids as an iterable of ints (e.g. ``[0, 2]``), a single
            int, or an already comma-separated string such as ``"0,2"``.
    """
    if isinstance(deviceid, str):
        ids = deviceid
    elif isinstance(deviceid, int):
        ids = str(deviceid)
    else:
        # Plain comma-separated list. Slicing str(tuple(...)) would give
        # "0, 2" (with spaces) and "0," for a single id.
        ids = ','.join(str(d) for d in deviceid)
    os.environ["CUDA_VISIBLE_DEVICES"] = ids
    print("Use GPU IDs :" + ids)
|
|
|
|
|
|
|
|
|
|
def ConvKernelToImage(model, layer, foldname):
    """Write every 2-D slice of ``model.features[layer].weight`` as a PNG.

    The weight is reshaped to ``(-1, kH, kW)`` using its last two axes. Each
    slice is min-max scaled to 0..255 on its own and written to
    ``<foldname>/<index>.png``.

    Args:
        model: Object with an indexable ``features`` attribute whose entry
            ``layer`` has a ``weight`` tensor (presumably a conv layer in an
            ``nn.Sequential``; confirm with callers).
        layer: Index into ``model.features``.
        foldname: Output directory; created (with parents) if missing.
    """
    os.makedirs(foldname, exist_ok=True)
    weight = model.features[layer].weight.detach().cpu().numpy()
    kernels = weight.reshape((-1, weight.shape[-2], weight.shape[-1]))
    for index, kernel in enumerate(kernels):
        low = np.min(kernel)
        high = np.max(kernel)
        if high > low:
            scaled = (kernel - low) * 255.0 / (high - low)
        else:
            # Constant kernel: avoid 0/0 (NaN) and write a black image.
            scaled = np.zeros(kernel.shape, dtype=np.float64)
        # Write 8-bit data: PNG's native depth, and cv2 does not accept
        # int64 arrays in many versions.
        cv2.imwrite(os.path.join(foldname, str(index) + ".png"), scaled.astype(np.uint8))
|
|
|
|
|
|
def NumpyToImage(numpydate, foldname, title="", maxImageWidth=128, maxImageHeight=128):
    """Tile the 2-D slices of an array into grid PNG images.

    The whole array is min-max scaled to 0..255 (one global range). It is
    reshaped to ``(-1, H, W)`` using its last two axes, and each slice gets a
    one-pixel zero border on its bottom and right. Slices are placed
    row-major into grids of ``maxImageWidth // (W + 1)`` columns by
    ``maxImageHeight // (H + 1)`` rows. Each grid is written to
    ``<foldname>/<title><start>-<start + per_image>.png``, and unused cells
    in the last grid stay black.

    Args:
        numpydate: Array-like with at least two dimensions.
        foldname: Output directory; created (with parents) if missing.
        title: Prefix for the output file names.
        maxImageWidth: Maximum grid width in pixels. At least one column is
            always used, even when a single slice is wider.
        maxImageHeight: Maximum grid height in pixels. At least one row is
            always used, even when a single slice is taller.
    """
    os.makedirs(foldname, exist_ok=True)

    data = np.asarray(numpydate, dtype=np.float64)
    low = np.min(data)
    high = np.max(data)
    if high > low:
        data = (data - low) * 255.0 / (high - low)
    else:
        # Constant input: avoid 0/0 (NaN) and render every slice black.
        data = np.zeros_like(data)
    slices = data.reshape((-1, data.shape[-2], data.shape[-1]))
    count, height, width = slices.shape

    # One-pixel border (bottom and right) visually separates tiles.
    cell_h, cell_w = height + 1, width + 1
    cells = np.zeros((count, cell_h, cell_w))
    cells[:, :height, :width] = slices

    # At least one cell per axis. Otherwise a slice larger than the maximum
    # image size gives zero slices per grid, and range() raises ValueError.
    cols = max(1, int(maxImageWidth / cell_w))
    rows = max(1, int(maxImageHeight / cell_h))
    per_image = cols * rows

    for start in range(0, count, per_image):
        grid = np.zeros((per_image, cell_h, cell_w))
        chunk = cells[start:start + per_image]
        grid[:chunk.shape[0]] = chunk
        # (rows, cols, cell_h, cell_w) -> (rows, cell_h, cols, cell_w), so a
        # plain reshape stitches the cells into one 2-D image.
        grid = grid.reshape((rows, cols, cell_h, cell_w)).swapaxes(1, 2)
        grid = grid.reshape((rows * cell_h, cols * cell_w)).astype("uint8")
        name = title + str(start) + "-" + str(start + per_image) + ".png"
        cv2.imwrite(os.path.join(foldname, name), grid)
|