# witnn/tools/utils.py


import sys
import math
import torch
import shutil
import time
import os
import random
from easydict import EasyDict as edict
import yaml
import numpy as np
import argparse


class AverageMeter(object):
    """Computes and stores the average and current value."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.
        self.avg = 0.
        self.sum = 0.
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
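
# Usage sketch (illustrative, not part of the original module): keep a running
# average of a per-batch metric; `loss` and `batch_size` are assumed to come
# from the surrounding training loop.
#
#   loss_meter = AverageMeter()
#   loss_meter.update(loss.item(), n=batch_size)
#   print(loss_meter.avg)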


def AdjustLearningRate(optimizermodule, iters, base_lr, policy_parameter, policy='step', multiple=[1]):
    if policy == 'fixed':
        lr = base_lr
    elif policy == 'step':
        lr = base_lr * (policy_parameter['gamma'] ** (iters // policy_parameter['step_size']))
    elif policy == 'exp':
        lr = base_lr * (policy_parameter['gamma'] ** iters)
    elif policy == 'inv':
        lr = base_lr * ((1 + policy_parameter['gamma'] * iters) ** (-policy_parameter['power']))
    elif policy == 'multistep':
        lr = base_lr
        for stepvalue in policy_parameter['stepvalue']:
            if iters >= stepvalue:
                lr *= policy_parameter['gamma']
            else:
                break
    elif policy == 'poly':
        lr = base_lr * ((1 - iters * 1.0 / policy_parameter['max_iter']) ** policy_parameter['power'])
    elif policy == 'sigmoid':
        lr = base_lr * (1.0 / (1 + math.exp(-policy_parameter['gamma'] * (iters - policy_parameter['stepsize']))))
    elif policy == 'multistep-poly':
        lr = base_lr
        stepstart = 0
        stepend = policy_parameter['max_iter']
        for stepvalue in policy_parameter['stepvalue']:
            if iters >= stepvalue:
                lr *= policy_parameter['gamma']
                stepstart = stepvalue
            else:
                stepend = stepvalue
                break
        lr = max(lr * policy_parameter['gamma'],
                 lr * (1 - (iters - stepstart) * 1.0 / (stepend - stepstart)) ** policy_parameter['power'])
    for i, param_group in enumerate(optimizermodule.param_groups):
        param_group['lr'] = lr * multiple[i]
    return lr
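
# Usage sketch (illustrative; `optimizer`, the iteration counter and the schedule
# values are assumptions, not taken from the original repo): with the 'step'
# policy the learning rate is multiplied by gamma every `step_size` iterations.
#
#   lr = AdjustLearningRate(optimizer, iters=it, base_lr=0.01,
#                           policy_parameter={'gamma': 0.1, 'step_size': 10000},
#                           policy='step')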


def ConstructModel(args):
    sys.path.append('../network/')
    m = __import__(args.modelname)
    model = m.ConstructModel(args)
    return model
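
# Usage sketch (illustrative): this dynamically imports ../network/<modelname>.py
# and calls its ConstructModel(args). The module name comes from the config; a
# hypothetical 'resnet' entry would load ../network/resnet.py.
#
#   model = ConstructModel(args)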


def GetCheckpoint(filename):
    state_dict = torch.load(filename, map_location='cpu')['state_dict']
    from collections import OrderedDict
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = k
        if k[0:7] == "module.":
            # strip the "module." prefix added by torch.nn.DataParallel
            name = k[7:]
        new_state_dict[name] = v
    return new_state_dict
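
# Usage sketch (illustrative; 'model_best.pth.tar' is a placeholder path): load a
# checkpoint that stores weights under a 'state_dict' key into a plain,
# non-DataParallel model.
#
#   model.load_state_dict(GetCheckpoint('model_best.pth.tar'))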


def SaveCheckpoint(model, filename='checkpoint_' + str(time.time()) + '.pth.tar', is_best=False):
    state = model.state_dict()
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, "best_" + filename)


def LoadCheckpoint(model, pretrained):
    if pretrained != 'None' and os.path.exists(pretrained):
        model.load_state_dict(GetCheckpoint(pretrained))
        return model
    else:
        return ""


def SaveModel(model, filename='checkpoint_' + str(time.time()) + '.pkl'):
    torch.save({
        # 'epoch': epoch,
        'model_state_dict': model.state_dict(),
        # 'optimizer_state_dict': optimizer.state_dict(),
        # 'loss': loss,
    }, filename)


def LoadModel(model, filename):
    checkpoint = torch.load(filename)
    model.load_state_dict(checkpoint['model_state_dict'])
    # optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    # epoch = checkpoint['epoch']
    # loss = checkpoint['loss']
    return model
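
# Usage sketch (illustrative; 'model.pkl' is a placeholder path): SaveModel and
# LoadModel round-trip through the 'model_state_dict' key.
#
#   SaveModel(model, filename='model.pkl')
#   model = LoadModel(model, 'model.pkl')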


def SetDevice(Obj, deviceid=range(torch.cuda.device_count())):
    if torch.cuda.is_available():
        if len(deviceid) > 1:
            gpu = range(len(deviceid))
            return torch.nn.DataParallel(Obj, device_ids=gpu).cuda()
        else:
            return Obj.cuda()
    else:
        return Obj
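
# Usage sketch (illustrative): wrap a model for the visible GPUs, falling back to
# CPU when CUDA is unavailable. The ids refer to the devices left visible after
# SetCUDAVISIBLEDEVICES (defined below).
#
#   model = SetDevice(model, deviceid=[0, 1])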


def ConstructDataset(args):
    sys.path.append('../dataset/')
    m = __import__(args.dataset)
    train_loader, val_loader = m.ConstructDataset(args)
    return train_loader, val_loader
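
# Usage sketch (illustrative): like ConstructModel, this imports
# ../dataset/<args.dataset>.py and delegates to its ConstructDataset(args).
#
#   train_loader, val_loader = ConstructDataset(args)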


def parse():
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default='config.yml',
                        dest='config', help='to set the parameters')
    return parser.parse_args()


def Config(filename):
    with open(filename, 'r') as f:
        # yaml.load() requires an explicit Loader in PyYAML >= 5.1;
        # safe_load avoids arbitrary object construction for a plain config file.
        parser = edict(yaml.safe_load(f))
    for x in parser:
        print('{}: {}'.format(x, parser[x]))
    return parser
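
# Usage sketch (illustrative): the typical entry-point pattern is to parse the
# --config flag and load the YAML file into an EasyDict of hyper-parameters.
#
#   args = Config(parse().config)
#   print(args.modelname)  # attribute access via EasyDict (assuming the key exists)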


def SetCUDAVISIBLEDEVICES(deviceid):
    ids = ','.join(str(d) for d in deviceid)  # plain comma-separated list, no trailing comma
    os.environ["CUDA_VISIBLE_DEVICES"] = ids
    print("Use GPU IDs: " + ids)