# Witllm/tools/show.py
#
# 24 lines
# 723 B
# Python

import plotly_express as px
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as Vision
import cv2
def DumpTensorToImage(tensor, name, autoPad=True, scale=1.0):
    """Write a 2-D tensor to disk as a grayscale image.

    The values are min-max normalised into the 0-255 range, converted to
    unsigned 8-bit, optionally resized, and saved with ``cv2.imwrite``.

    Args:
        tensor: A 2-D ``torch.Tensor``. It is cast to float and copied to
            the CPU before being converted to a NumPy array.
        name: Output path handed to ``cv2.imwrite``.
        autoPad: When true and the longer side is more than 16 times the
            shorter side, the image is resized (stretched, not padded) to
            a square whose side equals the longer dimension.
        scale: Factor applied to both image dimensions; ``1.0`` skips the
            resize.

    Raises:
        ValueError: If ``tensor`` is not 2-D or contains no elements.
    """
    if len(tensor.shape) != 2:
        # Raising a bare string is itself a TypeError in Python 3.
        raise ValueError("Error input dims")
    if tensor.numel() == 0:
        raise ValueError("Empty input tensor")

    tensor = tensor.detach().float()
    maxv = torch.max(tensor)
    minv = torch.min(tensor)
    span = maxv - minv
    if span > 0:
        # Scaling by 256 keeps the original binning; the clamp stops the
        # maximum element (exactly 256) from overflowing uint8.
        tensor = ((tensor - minv) / span * 256).clamp(max=255)
    else:
        # Constant input: avoid dividing by zero and emit a black image.
        tensor = torch.zeros_like(tensor)

    img = tensor.byte().cpu().numpy()
    height, width = img.shape

    if autoPad and (max(height, width) / min(height, width) > 16):
        side = max(height, width)
        img = cv2.resize(img, (side, side))
        height, width = img.shape

    if scale != 1.0:
        # cv2.resize expects dsize as (width, height), i.e. the reverse of
        # the NumPy (rows, cols) shape. Keep at least one pixel per side.
        new_width = max(1, int(width * scale))
        new_height = max(1, int(height * scale))
        img = cv2.resize(img, (new_width, new_height))

    cv2.imwrite(name, img)