152 lines
4.7 KiB
Python
152 lines
4.7 KiB
Python
|
import sys
|
||
|
import math
|
||
|
import torch
|
||
|
import shutil
|
||
|
import time
|
||
|
import os
|
||
|
import random
|
||
|
from easydict import EasyDict as edict
|
||
|
import yaml
|
||
|
import numpy as np
|
||
|
import argparse
|
||
|
|
||
|
class AverageMeter(object):
    """Track the latest value plus a running weighted average.

    Each ``update(val, n)`` adds ``val * n`` to ``sum`` and ``n`` to ``count``,
    so ``avg`` is ``sum / count`` over everything recorded since the last reset.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero out the latest value, the running sum/average and the count."""
        self.val = self.avg = self.sum = 0.
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and recompute ``avg``."""
        self.val = val
        weighted = val * n
        self.sum += weighted
        self.count += n
        self.avg = self.sum / self.count
|
||
|
|
||
|
def _poly_factor(progress, power):
    """Return ``(1 - progress) ** power`` with the base clamped at 0.

    Without the clamp, ``progress > 1`` (an iteration past the end of the
    schedule) raises a negative number to a fractional power, which in
    Python 3 produces a complex number instead of a learning rate.
    """
    return max(0.0, 1.0 - progress) ** power


def AdjustLearningRate(optimizermodule, iters, base_lr, policy_parameter, policy='step', multiple=None):
    """Compute the learning rate for iteration ``iters`` and write it into the optimizer.

    Args:
        optimizermodule: optimizer exposing ``param_groups`` (a list of dicts);
            every group's ``'lr'`` entry is overwritten.
        iters: current iteration number.
        base_lr: learning rate before any decay.
        policy_parameter: dict of schedule settings. Keys read per policy:
            'fixed':          (none)                  -> base_lr
            'step':           gamma, step_size        -> base_lr * gamma ** (iters // step_size)
            'exp':            gamma                   -> base_lr * gamma ** iters
            'inv':            gamma, power            -> base_lr * (1 + gamma * iters) ** -power
            'multistep':      gamma, stepvalue        -> base_lr * gamma ** (#stepvalues reached)
            'poly':           max_iter, power         -> base_lr * (1 - iters / max_iter) ** power
            'sigmoid':        gamma, stepsize         -> base_lr / (1 + exp(-gamma * (iters - stepsize)))
            'multistep-poly': gamma, stepvalue, max_iter, power
                -> multistep decay, plus a poly decay across the current
                   segment [last reached stepvalue, next stepvalue or max_iter),
                   never going below (multistep lr) * gamma.
            ``stepvalue`` is expected in ascending order: the scan stops at the
            first step not yet reached. For the poly-style policies, iterations
            past the end of the schedule decay to 0 instead of going complex.
        policy: name of the schedule (see above).
        multiple: per-param-group multipliers; group ``i`` gets
            ``lr * multiple[i]``. Defaults to 1 for every group.

    Returns:
        The learning rate before per-group multipliers are applied.

    Raises:
        ValueError: ``policy`` is not one of the names above.
    """
    if policy == 'fixed':
        lr = base_lr
    elif policy == 'step':
        lr = base_lr * (policy_parameter['gamma'] ** (iters // policy_parameter['step_size']))
    elif policy == 'exp':
        lr = base_lr * (policy_parameter['gamma'] ** iters)
    elif policy == 'inv':
        lr = base_lr * ((1 + policy_parameter['gamma'] * iters) ** (-policy_parameter['power']))
    elif policy == 'multistep':
        lr = base_lr
        for stepvalue in policy_parameter['stepvalue']:
            if iters >= stepvalue:
                lr *= policy_parameter['gamma']
            else:
                break
    elif policy == 'poly':
        progress = iters * 1.0 / policy_parameter['max_iter']
        lr = base_lr * _poly_factor(progress, policy_parameter['power'])
    elif policy == 'sigmoid':
        lr = base_lr * (1.0 / (1 + math.exp(-policy_parameter['gamma'] * (iters - policy_parameter['stepsize']))))
    elif policy == 'multistep-poly':
        lr = base_lr
        stepstart = 0
        stepend = policy_parameter['max_iter']
        for stepvalue in policy_parameter['stepvalue']:
            if iters >= stepvalue:
                lr *= policy_parameter['gamma']
                stepstart = stepvalue
            else:
                stepend = stepvalue
                break
        span = stepend - stepstart
        # A zero-length segment (last stepvalue == max_iter) would divide by
        # zero; treat it as fully decayed so the gamma floor applies.
        progress = 1.0 if span == 0 else (iters - stepstart) * 1.0 / span
        lr = max(lr * policy_parameter['gamma'],
                 lr * _poly_factor(progress, policy_parameter['power']))
    else:
        raise ValueError('Unknown learning-rate policy: {!r}'.format(policy))

    if multiple is None:
        # One multiplier per group; the old default ``[1]`` raised IndexError
        # for optimizers with more than one param group.
        multiple = [1] * len(optimizermodule.param_groups)

    for i, param_group in enumerate(optimizermodule.param_groups):
        param_group['lr'] = lr * multiple[i]

    return lr
|
||
|
|
||
|
def ConstructModel(args):
    """Build a model using the ``ConstructModel`` factory of module ``args.modelname``.

    The module is imported from the relative directory ``../network/``, so the
    process working directory determines where it is found. The directory is
    appended to ``sys.path`` only if it is not already present, so repeated
    calls no longer keep growing ``sys.path``.

    Args:
        args: namespace with a ``modelname`` attribute; it is also forwarded
            unchanged to the module's ``ConstructModel``.

    Returns:
        Whatever the module's ``ConstructModel(args)`` returns.
    """
    network_dir = '../network/'
    if network_dir not in sys.path:
        sys.path.append(network_dir)
    m = __import__(args.modelname)
    model = m.ConstructModel(args)
    return model
|
||
|
|
||
|
|
||
|
def GetCheckpoint(filename):
    """Load a checkpoint onto CPU and return its state dict without ``module.`` prefixes.

    ``torch.nn.DataParallel`` keeps the wrapped model as ``.module``, so its
    state-dict keys start with ``module.``. Stripping that prefix lets the
    weights load into an unwrapped model.

    Accepted file layouts:
      * ``{'state_dict': <state dict>}``
      * ``{'model_state_dict': <state dict>}`` (the key ``SaveModel`` uses)
      * a bare state dict

    Args:
        filename: path to a file written with ``torch.save``.

    Returns:
        ``collections.OrderedDict`` of parameter name -> tensor, in file order.
    """
    from collections import OrderedDict

    checkpoint = torch.load(filename, map_location='cpu')
    state_dict = checkpoint
    if isinstance(checkpoint, dict):
        for key in ('state_dict', 'model_state_dict'):
            if key in checkpoint:
                state_dict = checkpoint[key]
                break

    prefix = 'module.'
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = k[len(prefix):] if k.startswith(prefix) else k
        new_state_dict[name] = v
    return new_state_dict
|
||
|
def SaveCheckpoint(model, filename=None, is_best=False):
    """Save ``model``'s weights as ``{'state_dict': model.state_dict()}``.

    This is the layout that ``LoadCheckpoint``/``GetCheckpoint`` read back via
    the ``'state_dict'`` key. (Previously the bound method ``model.state_dict``
    was saved instead of its result.)

    Args:
        model: module providing ``state_dict()``.
        filename: destination path. Defaults to
            ``checkpoint_<time.time()>.pth.tar``, generated on each call.
        is_best: when true, also copy the file to ``best_<basename>`` in the
            same directory as ``filename``.
    """
    if filename is None:
        # Build the default per call: a default-argument expression is evaluated
        # once at import, which made every default call overwrite one file.
        filename = 'checkpoint_' + str(time.time()) + '.pth.tar'
    torch.save({'state_dict': model.state_dict()}, filename)
    if is_best:
        # Prefix only the basename; "best_" + "dir/x" would point into "best_dir/".
        directory, base = os.path.split(filename)
        shutil.copyfile(filename, os.path.join(directory, "best_" + base))
|
||
|
def LoadCheckpoint(model , pretrained):
    """Load weights from ``pretrained`` into ``model`` when that file exists.

    Args:
        model: module whose ``load_state_dict`` receives the weights.
        pretrained: checkpoint path; the literal string ``'None'`` means
            "no checkpoint".

    Returns:
        ``model`` after loading, or the empty string ``""`` when ``pretrained``
        is ``'None'`` or is not an existing path.
        NOTE(review): the mixed return type is preserved for existing callers;
        ``model = LoadCheckpoint(model, path)`` replaces ``model`` with ``""``
        when no checkpoint is found.
    """
    have_checkpoint = pretrained != 'None' and os.path.exists(pretrained)
    if not have_checkpoint:
        return ""
    model.load_state_dict(GetCheckpoint(pretrained))
    return model
|
||
|
|
||
|
|
||
|
def SaveModel(model , filename=None):
    """Save ``model``'s weights as ``{'model_state_dict': model.state_dict()}``.

    Only the model weights are stored (no optimizer state, epoch or loss);
    ``LoadModel`` reads the same ``'model_state_dict'`` key.

    Args:
        model: module providing ``state_dict()``.
        filename: destination path. Defaults to
            ``checkpoint_<time.time()>.pkl``, generated on each call.
    """
    if filename is None:
        # Computed per call: as a default-argument expression the timestamp was
        # fixed at import time, so every default call overwrote the same file.
        filename = 'checkpoint_' + str(time.time()) + '.pkl'
    torch.save({
        'model_state_dict': model.state_dict(),
    }, filename)
|
||
|
def LoadModel(model, filename, map_location='cpu'):
    """Load weights saved by ``SaveModel`` into ``model`` and return it.

    Args:
        model: module whose ``load_state_dict`` receives the weights.
        filename: file containing ``{'model_state_dict': <state dict>}``.
        map_location: forwarded to ``torch.load``. Defaults to ``'cpu'`` so
            files saved from GPU tensors also load on CPU-only machines.
            ``load_state_dict`` copies values into the model's existing
            parameters, so the model stays on its own device. Pass ``None``
            to restore tensors to their saved devices.

    Returns:
        ``model``, with its weights replaced.

    Raises:
        KeyError: the file has no ``'model_state_dict'`` entry.
    """
    checkpoint = torch.load(filename, map_location=map_location)
    model.load_state_dict(checkpoint['model_state_dict'])
    return model
|
||
|
|
||
|
|
||
|
def SetDevice(Obj , deviceid=None):
    """Move ``Obj`` to the GPU when CUDA is available.

    Args:
        Obj: module (anything with ``.cuda()``; wrapped in
            ``torch.nn.DataParallel`` when more than one device is requested).
        deviceid: sequence of device ids; only its length is used. Defaults
            to ``range(torch.cuda.device_count())``, evaluated when the
            function is called rather than once at import time.

    Returns:
        ``Obj`` unchanged if CUDA is unavailable; otherwise ``Obj.cuda()`` for
        zero/one device, or a ``DataParallel`` wrapper on devices
        ``0..len(deviceid)-1``.
    """
    if not torch.cuda.is_available():
        return Obj
    if deviceid is None:
        deviceid = range(torch.cuda.device_count())
    if len(deviceid) > 1:
        # NOTE(review): ids are renumbered 0..n-1, which presumably relies on
        # CUDA_VISIBLE_DEVICES already restricting the process to the requested
        # GPUs (see SetCUDAVISIBLEDEVICES) — confirm with callers.
        gpu = range(len(deviceid))
        return torch.nn.DataParallel(Obj, device_ids=gpu).cuda()
    return Obj.cuda()
|
||
|
|
||
|
def ConstructDataset(args):
    """Build train/val loaders using ``ConstructDataset`` of module ``args.dataset``.

    The module is imported from the relative directory ``../dataset/``, so the
    process working directory determines where it is found. The directory is
    appended to ``sys.path`` only if it is not already present, so repeated
    calls no longer keep growing ``sys.path``.

    Args:
        args: namespace with a ``dataset`` attribute; it is also forwarded
            unchanged to the module's ``ConstructDataset``.

    Returns:
        Tuple ``(train_loader, val_loader)`` unpacked from the module's result.
    """
    dataset_dir = '../dataset/'
    if dataset_dir not in sys.path:
        sys.path.append(dataset_dir)
    m = __import__(args.dataset)
    train_loader, val_loader = m.ConstructDataset(args)
    return train_loader, val_loader
|
||
|
|
||
|
|
||
|
|
||
|
|
||
|
def parse():
    """Parse the command line.

    Returns:
        ``argparse.Namespace`` whose ``config`` attribute is the path given
        with ``--config`` (``'config.yml'`` when the flag is omitted).
    """
    cli = argparse.ArgumentParser()
    cli.add_argument(
        '--config',
        dest='config',
        type=str,
        default='config.yml',
        help='to set the parameters',
    )
    options = cli.parse_args()
    return options
|
||
|
|
||
|
def Config(filename):
    """Load a YAML config file into an ``EasyDict`` and print each top-level entry.

    Uses ``yaml.safe_load``. A bare ``yaml.load(f)`` without ``Loader`` raises
    ``TypeError`` on PyYAML >= 6, and on older releases it could construct
    arbitrary Python objects from YAML tags.

    Args:
        filename: path to a UTF-8 encoded YAML file.

    Returns:
        ``EasyDict`` of the parsed configuration (attribute-style access).
    """
    with open(filename, 'r', encoding='utf-8') as f:
        parser = edict(yaml.safe_load(f))
    for x in parser:
        print('{}: {}'.format(x, parser[x]))
    return parser
|
||
|
|
||
|
def SetCUDAVISIBLEDEVICES(deviceid):
    """Set ``CUDA_VISIBLE_DEVICES`` to the given GPU ids and print them.

    The value is written as a plain comma-separated list such as ``"0,2"``.
    The previous ``str(tuple(ids))[1:-1]`` trick produced ``"0,"`` for a
    single id, inserted spaces (``"0, 2"``), and failed on a bare int.

    Args:
        deviceid: iterable of GPU ids, a single int id, or an already
            formatted string such as ``"0,1"``.

    NOTE(review): CUDA typically reads this variable when it initializes in
    the process, so this presumably needs to run before any CUDA call.
    """
    if isinstance(deviceid, (int, str)):
        deviceid = [deviceid]
    ids = ','.join(str(d) for d in deviceid)
    os.environ["CUDA_VISIBLE_DEVICES"] = ids
    print("Use GPU IDs :" + ids)
|
||
|
|
||
|
|