import math
import os

import cv2
import numpy as np
import plotly_express as px
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as Vision

# Aspect ratio above which DumpTensorToImage(autoPad=True) stretches the image to a square.
_MAX_ASPECT_RATIO = 16


def DumpTensorToImage(tensor, name, autoPad=True, scale=1.0, auto2d=True):
    """Min-max normalize a 1D/2D tensor to 8-bit grayscale and write it with cv2.imwrite.

    Args:
        tensor: 1D or 2D torch tensor (any device); converted to float before normalizing.
        name: Output path passed to ``cv2.imwrite``.
        autoPad: If the longer edge is more than 16x the shorter one, resize the
            image to a square whose side equals the longer edge.
        scale: Uniform resize factor applied last (1.0 = no resize).
        auto2d: Fold a 1D tensor into the smallest square that holds it,
            zero-padding the tail. When False, a 1D tensor is written as a
            single-column image.

    Returns:
        The boolean returned by ``cv2.imwrite`` (False if the write failed).

    Raises:
        ValueError: If the tensor is not 1D/2D, or is empty.
    """
    if tensor.dim() not in (1, 2):
        raise ValueError("Error input dims")
    if tensor.numel() == 0:
        raise ValueError("Error empty tensor")

    tensor = tensor.float()
    maxv = torch.max(tensor)
    minv = torch.min(tensor)
    value_range = maxv - minv
    if value_range > 0:
        tensor = ((tensor - minv) / value_range) * 255
    else:
        # Constant (or NaN) tensor: the original 0/0 produced NaN, whose byte cast is undefined.
        tensor = torch.zeros_like(tensor)
    img = tensor.byte().cpu().numpy()

    if img.ndim == 1:
        if auto2d:
            # Fold 1D data into a square 2D image; isqrt avoids float rounding for large sizes.
            count = img.shape[0]
            side = math.isqrt(count - 1) + 1  # == ceil(sqrt(count)) for count >= 1
            img = np.pad(img, (0, side * side - count)).reshape(side, side)
        else:
            # Treat a 1D tensor as a single column so the 2D logic below applies uniformly.
            img = img.reshape(-1, 1)

    height, width = img.shape
    if autoPad and max(height, width) / min(height, width) > _MAX_ASPECT_RATIO:
        # Stretch very elongated images into a square so they remain viewable.
        side = max(height, width)
        img = cv2.resize(img, (side, side))

    if scale != 1.0:
        height, width = img.shape
        # cv2.resize takes dsize as (width, height), i.e. (cols, rows); clamp to >= 1 pixel.
        new_size = (max(1, int(width * scale)), max(1, int(height * scale)))
        img = cv2.resize(img, new_size)

    return cv2.imwrite(name, img)


def DumpTensorToLog(tensor, name="log"):
    """Write every element of ``tensor`` (flattened, converted to float) to a text file, one per line."""
    data = tensor.detach().reshape([-1]).float().cpu().numpy().tolist()
    # Text mode already translates "\n" to the platform line ending; writing os.linesep
    # here would produce "\r\r\n" on Windows.
    with open(name, "w") as f:
        f.writelines(str(d) + "\n" for d in data)


def DumpTensorToFile(tensor, name="tensor.pt"):
    """Save ``tensor`` (moved to CPU) to ``name`` with ``torch.save``."""
    torch.save(tensor.cpu(), name)


def LoadTensorToFile(name="tensor.pt"):
    """Load and return the object saved at ``name`` via ``torch.load``.

    NOTE(security): torch.load is pickle-based; only load files from trusted sources.
    """
    return torch.load(name)


def DumpListToFile(list, name="list"):  # parameter name kept for keyword callers despite shadowing the builtin
    """Write ``str()`` of each item in ``list`` to a text file, one item per line."""
    # str(d) rather than "%s" % d: the latter breaks or mis-formats tuple items.
    with open(name, "w") as f:
        f.writelines(str(d) + "\n" for d in list)