import plotly_express as px
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as Vision
import cv2
import math
import numpy as np
import os


def _CeilSqrt(n):
    """Return the smallest integer s with s * s >= n (for n >= 0)."""
    # Integer square root avoids float rounding of n ** 0.5 for large n.
    root = math.isqrt(n)
    return root if root * root == n else root + 1


def _TileChannels(tensor):
    """Tile a 3D tensor's dim-0 planes into one 2D grid tensor.

    The grid is x by x planes with x = ceil(sqrt(tensor.shape[0])). Each plane
    gets one zero row at the bottom and one zero column on the right as a
    separator. Grid cells beyond the number of planes are filled with zeros.
    Returns a tensor of shape (x * (H + 1), x * (W + 1)).
    """
    channel, height, width = tensor.shape
    x = _CeilSqrt(channel)
    # F.pad pads from the last dim backwards:
    # (W left, W right, H top, H bottom, dim0 front, dim0 back).
    tensor = F.pad(
        tensor, (0, 1, 0, 1, 0, x * x - channel), mode="constant", value=0
    )
    tensor = tensor.reshape((x, x, height + 1, width + 1))
    # (grid_row, grid_col, h, w) -> (grid_row, h, grid_col, w) so that each
    # output row walks across one row of tiles before moving down.
    tensor = tensor.permute((0, 2, 1, 3))
    return tensor.reshape((x * (height + 1), x * (width + 1)))


def DumpTensorToImage(tensor, name, forceSquare=True, scale=1.0):
    """Save a 1D, 2D or 3D tensor as an 8-bit grayscale image with cv2.imwrite.

    Values are min-max normalized to [0, 255]. A constant tensor is written
    as all zeros.

    A 3D tensor is treated as a stack of 2D planes along dim 0. The planes are
    tiled into a near-square grid with a one-pixel zero separator, and the grid
    is written without square stretching.

    A 1D tensor is zero-padded to the next perfect-square length and folded
    into a square image.

    Args:
        tensor: torch tensor with 1, 2 or 3 dimensions, on any device.
        name: output file path, passed straight to cv2.imwrite.
        forceSquare: stretch the image to max(h, w) x max(h, w). Not applied
            to the tiled grid produced from 3D input.
        scale: factor applied to both height and width after any square
            stretching. Each side is clamped to at least 1 pixel.

    Raises:
        ValueError: if tensor is not 1D, 2D or 3D.
    """
    if tensor.dim() not in (1, 2, 3):
        # Previously `raise ("...")`, which raises a str and therefore fails
        # with an unrelated TypeError in Python 3.
        raise ValueError(
            "Error input dims: expected a 1D, 2D or 3D tensor, got %dD"
            % tensor.dim()
        )
    # Detach so .numpy() works on tensors that require grad.
    tensor = tensor.detach()

    if tensor.dim() == 3:
        DumpTensorToImage(
            _TileChannels(tensor), name, forceSquare=False, scale=scale
        )
        return

    tensor = tensor.float()
    maxv = torch.max(tensor)
    minv = torch.min(tensor)
    span = maxv - minv
    if span > 0:
        tensor = (tensor - minv) / span * 255
    else:
        # Constant input: 0 / 0 would give NaN, whose uint8 cast is undefined.
        tensor = torch.zeros_like(tensor)
    img = tensor.byte().cpu().numpy()

    if img.ndim == 1:
        # Fold 1D data into the smallest square that holds it; zero-pad the tail.
        side = _CeilSqrt(img.shape[0])
        img = np.pad(img, (0, side * side - img.shape[0])).reshape((side, side))

    if forceSquare:
        # Stretch to a square.
        side = max(img.shape)
        img = cv2.resize(img, (side, side))

    if scale != 1.0:
        height, width = img.shape
        # cv2.resize takes dsize as (width, height). Passing (height, width)
        # transposed the aspect ratio of non-square images.
        img = cv2.resize(
            img, (max(1, int(width * scale)), max(1, int(height * scale)))
        )

    cv2.imwrite(name, img)


def DumpTensorToLog(tensor, name="log"):
    """Write every element of tensor, flattened and cast to float, to the
    text file name, one value per line."""
    data = tensor.detach().reshape([-1]).float().cpu().numpy().tolist()
    # Text mode already translates "\n" to the platform line ending. Writing
    # os.linesep here produced "\r\r\n" on Windows.
    with open(name, "w") as f:
        f.writelines(str(d) + "\n" for d in data)


def DumpTensorToFile(tensor, name="tensor.pt"):
    """Save a CPU copy of tensor to name with torch.save."""
    torch.save(tensor.cpu(), name)


def LoadTensorToFile(name="tensor.pt"):
    """Load and return the object stored at name with torch.load."""
    # NOTE(review): torch.load unpickles data; only load files you trust.
    return torch.load(name)


def DumpListToFile(list, name="list"):  # `list` kept for keyword callers
    """Write str() of each item of list to the text file name, one per line.

    Uses str() rather than "%s" % item, because the latter mis-formats
    or raises on tuple items.
    """
    with open(name, "w") as f:
        f.writelines(str(d) + "\n" for d in list)


# Running counters updated by ProbGE0: elements seen that were >= 0, and
# total elements seen.
prob_true = 0
prob_all = 0


def ProbGE0(tensor: torch.Tensor):
    """Add tensor's count of elements >= 0 to prob_true and its total
    element count to prob_all."""
    global prob_true
    global prob_all
    prob_true += tensor.ge(0).sum().item()
    prob_all += tensor.numel()
def DumpProb():
    """Print the module counters prob_true and prob_all, and their ratio as
    a percentage.

    When prob_all is 0 (nothing accumulated yet), the ratio is printed as
    "n/a" instead of raising ZeroDivisionError.
    """
    # Read-only access to the module-level counters; no `global` needed.
    print("prob_true : " + str(prob_true))
    print("prob_all : " + str(prob_all))
    if prob_all == 0:
        print("prob : n/a")
    else:
        print("prob : " + str((prob_true * 100) / prob_all) + "%")