Witllm/tools/show.py

58 lines
1.6 KiB
Python

import plotly_express as px
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as Vision
import cv2
import math
import numpy as np
import os
def DumpTensorToImage(tensor, name, autoPad=True, scale=1.0, auto2d=True):
    """Write a 1-D or 2-D tensor to disk as an 8-bit grayscale image.

    Values are min-max normalised to [0, 255], cast to uint8 and written with
    ``cv2.imwrite``.

    Args:
        tensor: 1-D or 2-D tensor of any dtype convertible with ``.float()``;
            may live on any device and may require grad.
        name: Output file path; OpenCV picks the format from the extension.
        autoPad: If True and the (2-D) image's aspect ratio exceeds 16:1,
            resize it to a square whose side is the longer dimension.
        scale: Extra resize factor applied to both dimensions at the end.
        auto2d: If True, 1-D data is zero-padded to the next perfect-square
            length and folded row-major into a square image. If False, 1-D
            data is written as a single column.

    Raises:
        ValueError: If ``tensor`` is neither 1-D nor 2-D.
        OSError: If ``cv2.imwrite`` reports that the file was not written.
    """
    if len(tensor.shape) not in (1, 2):
        raise ValueError("Error input dims")

    # detach() so tensors that require grad can be converted to numpy.
    tensor = tensor.detach().float()
    maxv = torch.max(tensor)
    minv = torch.min(tensor)
    span = maxv - minv
    if span > 0:
        tensor = (tensor - minv) / span * 255
    else:
        # Zero (or NaN) value range: normalising would give NaN, whose uint8
        # cast is undefined, so emit an all-black image instead.
        tensor = torch.zeros_like(tensor)
    img = tensor.byte().cpu().numpy()

    if auto2d and img.ndim == 1:
        # Fold 1-D data into a square 2-D image, zero-padding the tail.
        count = img.shape[0]
        side = math.isqrt(count)
        if side * side < count:
            side += 1
        img = np.pad(img, (0, side * side - count)).reshape(side, side)

    if img.ndim == 1:
        # Unfolded 1-D data (auto2d=False) is written as a single column,
        # the layout OpenCV uses for a 1-D array; no aspect-ratio stretch.
        img = img.reshape(-1, 1)
    elif autoPad and max(img.shape) / min(img.shape) > 16:
        # Stretch very elongated images to a square so they stay visible.
        side = max(img.shape)
        img = cv2.resize(img, (side, side))

    if scale != 1.0:
        height, width = img.shape
        # cv2.resize expects dsize as (width, height); keep at least 1 pixel.
        dsize = (max(1, int(width * scale)), max(1, int(height * scale)))
        img = cv2.resize(img, dsize)

    if not cv2.imwrite(name, img):
        raise OSError(f"cv2.imwrite could not write image to {name!r}")
def DumpTensorToLog(tensor, name="log"):
    """Write every element of *tensor*, flattened, to a text file.

    Elements are converted to float32, then to Python floats, and written
    one per line using ``str()`` formatting (e.g. ``1.0``).

    Args:
        tensor: Tensor of any shape and device; may require grad.
        name: Output file path; an existing file is overwritten.
    """
    # detach() so tensors that require grad can be read out.
    values = tensor.detach().reshape([-1]).float().cpu().tolist()
    # Text mode already translates "\n" to the platform line ending; writing
    # os.linesep here would produce "\r\r\n" on Windows.
    with open(name, "w", encoding="utf-8") as f:
        f.writelines(f"{v}\n" for v in values)
def DumpTensorToFile(tensor, name="tensor.pt"):
    """Serialise a CPU copy of *tensor* to the path *name* via ``torch.save``.

    The tensor is moved to the CPU first, so the saved file does not refer
    to the device the tensor originally lived on.
    """
    host_tensor = tensor.cpu()
    torch.save(host_tensor, name)
def LoadTensorToFile(name="tensor.pt"):
    """Load and return the object stored at the path *name* by ``torch.save``.

    NOTE(review): ``torch.load`` is pickle-based (unless ``weights_only`` is
    in effect for the installed torch version), so only load trusted files.
    """
    loaded = torch.load(name)
    return loaded
def DumpListToFile(list, name="list"):
    """Write each item of *list* to a text file, one ``str(item)`` per line.

    Args:
        list: Iterable of items to write. (The name shadows the builtin
            ``list``; it is kept so keyword callers keep working.)
        name: Output file path; an existing file is overwritten.
    """
    # str() formatting (not "%s" % item) so tuple items are written as
    # tuples rather than being treated as format arguments.
    # Text mode already translates "\n" to the platform line ending; writing
    # os.linesep here would produce "\r\r\n" on Windows.
    with open(name, "w", encoding="utf-8") as f:
        f.writelines(f"{item}\n" for item in list)