import plotly_express as px
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as Vision
import cv2


def DumpTensorToImage(tensor, name, autoPad=True, scale=1.0):
    """Min-max normalize a 2-D tensor to 8-bit grayscale and write it to disk.

    Args:
        tensor: 2-D ``torch.Tensor`` to visualize; it is converted to float,
            detached and moved to CPU before writing.
        name: Output file path passed to ``cv2.imwrite``.
        autoPad: If True and the aspect ratio (long side / short side) is
            greater than 16, the image is resized (stretched, not padded) to a
            square whose side is the longer dimension, so very thin tensors
            stay visible.
        scale: Factor applied to both image dimensions before writing; each
            resulting dimension is at least 1 pixel.

    Raises:
        ValueError: If ``tensor`` is not 2-D.
    """
    if len(tensor.shape) != 2:
        raise ValueError(
            f"Error input dims: expected a 2-D tensor, got shape {tuple(tensor.shape)}"
        )

    tensor = tensor.detach().float()
    minv = torch.min(tensor)
    maxv = torch.max(tensor)
    value_range = maxv - minv
    if value_range > 0:
        # Map [minv, maxv] onto [0, 255]; 255 (not 256) so the max value does
        # not overflow uint8 and wrap around to 0.
        tensor = (tensor - minv) / value_range * 255.0
    else:
        # Constant (or NaN) tensor: avoid 0/0 -> NaN, whose uint8 cast is
        # undefined; emit an all-black image instead.
        tensor = torch.zeros_like(tensor)
    img = tensor.clamp(0, 255).byte().cpu().numpy()

    rows, cols = img.shape
    if autoPad and max(rows, cols) / min(rows, cols) > 16:
        side = max(rows, cols)
        img = cv2.resize(img, (side, side))
        rows, cols = img.shape
    if scale != 1.0:
        # cv2.resize expects dsize as (width, height) == (cols, rows).
        new_w = max(1, int(cols * scale))
        new_h = max(1, int(rows * scale))
        img = cv2.resize(img, (new_w, new_h))
    cv2.imwrite(name, img)