import math
import os
import warnings

import plotly_express as px
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as Vision
import cv2
import numpy as np

# NOTE(review): `px`, `Vision` and `os` are not used by the code in this
# module; the imports are kept unchanged.

# File extensions accepted by DumpTensorToImage (compared case-sensitively).
_IMAGE_EXTENSIONS = {"jpg", "png", "bmp"}


def _tile_channels(tensor, Contrast, GridValue):
    """Normalise each slice of a (C, H, W) tensor to 0..255 and tile them into one 2-D grid.

    The grid has ceil(sqrt(C)) x ceil(sqrt(C)) cells. Unused cells and a
    1-pixel separator after every cell row/column are filled with GridValue.
    GridValue is given on the 0..255 scale.
    """
    channel, height, width = tensor.shape
    side = math.ceil(channel ** 0.5)
    calc = tensor.detach().reshape(channel, height * width)
    if not calc.is_floating_point():
        # Integer/bool inputs need a floating dtype for the arithmetic below.
        calc = calc.float()
    if not Contrast:
        # Per-slice range, kept as a column so it broadcasts over the pixels.
        hi = calc.max(1, keepdim=True)[0]
        lo = calc.min(1, keepdim=True)[0]
    else:
        lo, hi = Contrast[0], Contrast[1]
    calc = (calc - lo) / (hi - lo) * 255.0
    calc = calc.reshape(channel, height, width)
    # Pad the slice axis up to side*side with blank cells.
    calc = F.pad(calc, (0, 0, 0, 0, 0, side * side - channel), mode="constant", value=GridValue)
    calc = calc.reshape(side, side, height, width)
    # Append one separator row and one separator column to every cell.
    calc = F.pad(calc, (0, 1, 0, 1, 0, 0), mode="constant", value=GridValue)
    # (row, col, H+1, W+1) -> (row, H+1, col, W+1) -> single 2-D mosaic.
    return calc.permute(0, 2, 1, 3).reshape(side * (height + 1), side * (width + 1))


def DumpTensorToImage(tensor, name, forceSquare=False, scale=1.0, Contrast=None, GridValue=None):
    """Render a 1-, 2- or 3-D tensor as a JET-coloured image file.

    Values are mapped linearly to 0..255, using the data range or ``Contrast``.
    They are then inverted (255 - v) and clipped to [0, 255]; NaN becomes 0 and
    +/-inf saturate. Finally ``cv2.COLORMAP_JET`` is applied and the result is
    written with ``cv2.imwrite``.

    Args:
        tensor: torch tensor with 1, 2 or 3 dimensions.
            - 1-D data is zero-padded to the next square length and folded
              into a square image.
            - 3-D data is treated as (C, H, W). Each of the C slices is
              normalised on its own, and the slices are tiled into a
              ceil(sqrt(C)) x ceil(sqrt(C)) grid with 1-pixel separators.
        name: Output path. Its extension must be "jpg", "png" or "bmp".
        forceSquare: Stretch the image to a max(H, W) square. Ignored for
            3-D input, whose tiled grid is rendered as-is.
        scale: Resize factor applied to both axes.
        Contrast: Optional ``[min, max]`` used for normalisation instead of
            the data's own range. It applies to every slice for 3-D input.
        GridValue: Fill value for separators/empty cells of 3-D input, on the
            0..255 scale. Defaults to 128.0.

    Raises:
        ValueError: If ``tensor`` is not 1-, 2- or 3-D, or ``name`` has an
            unsupported extension.
    """
    if len(tensor.shape) not in (1, 2, 3):
        raise ValueError("Error input dims")
    if "." not in name or name.rsplit(".", 1)[-1] not in _IMAGE_EXTENSIONS:
        raise ValueError("Error input name")

    if len(tensor.shape) == 3:
        if GridValue is None:  # `is None` so an explicit 0 is honoured
            GridValue = 128.0
        grid = _tile_channels(tensor, Contrast, GridValue)
        # The grid is already on the 0..255 scale: render with a fixed range.
        DumpTensorToImage(grid, name, forceSquare=False, scale=scale, Contrast=[0.0, 255.0], GridValue=GridValue)
        return

    tensor = tensor.detach().float()
    if not Contrast:
        maxv = torch.max(tensor)
        minv = torch.min(tensor)
    else:
        maxv = Contrast[1]
        minv = Contrast[0]
    tensor = (tensor - minv) / (maxv - minv) * 255.0
    img = tensor.cpu().numpy()

    if img.ndim == 1:
        # Fold 1-D data into a square 2-D image, zero-padding the tail.
        side = math.ceil(img.shape[0] ** 0.5)
        img = np.pad(img, (0, side * side - img.shape[0])).reshape(side, side)
    if forceSquare:
        # Stretch into a square whose side is the longer image dimension.
        side = max(img.shape)
        img = cv2.resize(img, (side, side))
    if scale != 1.0:
        height, width = img.shape
        # cv2.resize takes dsize as (width, height), not (rows, cols).
        img = cv2.resize(img, (max(1, int(width * scale)), max(1, int(height * scale))))

    # Invert, then clip: NaN (e.g. from a zero value range) -> 0, +/-inf saturate.
    img = np.clip(np.nan_to_num(255.0 - img, nan=0.0), 0, 255)
    colored = cv2.applyColorMap(img.astype(np.uint8), cv2.COLORMAP_JET)
    if not cv2.imwrite(name, colored):
        warnings.warn("DumpTensorToImage: failed to write %r" % (name,))


def DumpTensorToLog(tensor, name="log"):
    """Write every element of ``tensor`` (flattened, as float32) to text file ``name``, one per line."""
    values = tensor.detach().reshape(-1).float().cpu().numpy().tolist()
    with open(name, "w") as f:
        # Text mode already translates "\n" to the platform newline; writing
        # os.linesep here would produce "\r\r\n" on Windows.
        f.writelines(str(v) + "\n" for v in values)


def DumpTensorToFile(tensor, name="tensor.pt"):
    """Save a CPU copy of ``tensor`` to ``name`` with ``torch.save``."""
    torch.save(tensor.cpu(), name)


def LoadTensorToFile(name="tensor.pt"):
    """Load and return the object stored in ``name`` with ``torch.load``."""
    # SECURITY: torch.load unpickles the file and can execute arbitrary code
    # depending on the torch version/defaults; only load trusted files.
    return torch.load(name)


def DumpListToFile(list, name="list"):
    """Write ``str(item)`` for every item of ``list`` to text file ``name``, one per line.

    The parameter name shadows the builtin ``list``. It is kept for
    keyword-argument compatibility.
    """
    with open(name, "w") as f:
        # str() rather than "%s" % item, which fails on multi-element tuples
        # and unwraps 1-tuples.
        f.writelines(str(item) + "\n" for item in list)


# Running counters: updated by ProbGE0, reported by DumpProb.
prob_true = 0
prob_all = 0


def ProbGE0(tensor: torch.Tensor):
    """Add the number of elements >= 0 to ``prob_true`` and the element count to ``prob_all``."""
    global prob_true, prob_all
    prob_true += tensor.ge(0).sum().item()
    prob_all += tensor.numel()


def DumpProb():
    """Print the ProbGE0 counters and the percentage of elements that were >= 0."""
    print("prob_true : " + str(prob_true))
    print("prob_all : " + str(prob_all))
    if prob_all == 0:
        # Nothing accumulated yet; avoid ZeroDivisionError.
        print("prob : n/a")
        return
    print("prob : " + str((prob_true * 100) / prob_all) + "%")