Witllm/tools/show.py

120 lines
3.7 KiB
Python
Raw Normal View History

2023-12-21 19:52:19 +08:00
import plotly_express as px
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as Vision
import cv2
2023-12-21 21:20:49 +08:00
import math
import numpy as np
2023-12-25 22:53:53 +08:00
import os
2024-09-16 18:46:09 +08:00
from pathlib import Path
2023-12-21 19:52:19 +08:00
2024-09-02 17:52:33 +08:00
def _tile_channels(tensor, contrast, grid_value):
    """Lay out a (C, H, W) tensor as a near-square grid of H x W tiles.

    Each channel is min-max scaled to [0, 255] on its own, or with the shared
    ``contrast`` = [min, max] range when one is given. Unused grid cells are
    filled with ``grid_value``. Every tile also gets a one-pixel ``grid_value``
    separator on its bottom and right edge.

    Returns a float tensor of shape (side * (H + 1), side * (W + 1)), where
    side = ceil(sqrt(C)).
    """
    channel, height, width = tensor.shape
    side = math.ceil(channel ** 0.5)
    flat = tensor.float().reshape(channel, height * width)
    if contrast is None:
        # Per-channel range; keepdim so it broadcasts across each channel's pixels.
        high = flat.max(dim=1, keepdim=True).values
        low = flat.min(dim=1, keepdim=True).values
    else:
        low, high = contrast[0], contrast[1]
    flat = (flat - low) / (high - low) * 255.0
    tiles = flat.reshape(channel, height, width)
    # Pad the channel axis up to side * side cells with the grid colour.
    tiles = F.pad(tiles, (0, 0, 0, 0, 0, side * side - channel), mode="constant", value=grid_value)
    tiles = tiles.reshape(side, side, height, width)
    # One-pixel separator after each tile (extra bottom row and right column).
    tiles = F.pad(tiles, (0, 1, 0, 1), mode="constant", value=grid_value)
    # (grid_row, grid_col, H+1, W+1) -> (grid_row, H+1, grid_col, W+1), so the
    # final reshape places the tiles of one grid row side by side.
    tiles = tiles.permute(0, 2, 1, 3)
    return tiles.reshape(side * (height + 1), side * (width + 1))


def DumpTensorToImage(tensor, name, forceSquare=False, scale=1.0, Contrast=None, GridValue=None):
    """Render a 1D, 2D or 3D tensor as a JET-coloured image and write it to ``name``.

    Args:
        tensor: tensor with 1, 2 or 3 dimensions.
            * 1D: after scaling, the values are zero-padded and folded row-major
              into the smallest square that holds them.
            * 2D: rendered as-is.
            * 3D (C, H, W): each channel becomes one tile of a near-square grid.
        name: output path. It must end in .jpg/.jpeg/.png/.bmp
            (case-insensitive). Missing parent directories are created.
        forceSquare: stretch a 1D/2D image to a max(H, W) square before any
            scaling. It is not forwarded for 3D input.
        scale: resize factor applied to both image dimensions.
        Contrast: optional [min, max] range that is mapped to [0, 255]. When
            None, the tensor's own min/max is used (per channel for 3D input).
            Values outside the range are clamped.
        GridValue: fill value, in the scaled 0..255 space, used for the grid
            padding of 3D input. Defaults to 128.0 when None.

    Raises:
        ValueError: if the tensor rank or the file extension is unsupported.
    """
    if len(tensor.shape) not in (1, 2, 3):
        raise ValueError("Error input dims")
    if "." not in name or name.rsplit(".", 1)[-1].lower() not in {"jpg", "jpeg", "png", "bmp"}:
        raise ValueError("Error input name")

    if len(tensor.shape) == 3:
        grid_value = 128.0 if GridValue is None else GridValue
        grid = _tile_channels(tensor, Contrast, grid_value)
        # The tiles are already in [0, 255], so render the grid with an identity range.
        DumpTensorToImage(grid, name, forceSquare=False, scale=scale, Contrast=[0.0, 255.0], GridValue=grid_value)
        return

    values = tensor.float()
    if Contrast is None:
        maxv = torch.max(values)
        minv = torch.min(values)
    else:
        minv, maxv = Contrast[0], Contrast[1]
    # A constant input (max == min) divides 0 by 0. The resulting NaNs are mapped to 0 below.
    values = (values - minv) / (maxv - minv) * 255.0
    img = values.detach().cpu().numpy()

    if img.ndim == 1:
        # Fold 1D data into the smallest square that holds it, zero-padding the tail.
        side = math.ceil(img.shape[0] ** 0.5)
        img = np.pad(img, (0, side * side - img.shape[0]), mode="constant", constant_values=0)
        img = img.reshape(side, side)
    if forceSquare:
        edge = max(img.shape)
        img = cv2.resize(img, (edge, edge))
    if scale != 1.0:
        height, width = img.shape[:2]
        # cv2.resize takes dsize as (width, height), i.e. (columns, rows).
        img = cv2.resize(img, (max(1, int(width * scale)), max(1, int(height * scale))))

    # Invert, so the largest values land at the low end of the colour map.
    # Then replace NaN/inf and clamp to the uint8 range.
    img = 255.0 - img
    img = np.nan_to_num(img, nan=0.0)
    img = np.clip(img, 0, 255)
    colored = cv2.applyColorMap(img.astype(np.uint8), cv2.COLORMAP_JET)

    Path(name).parent.mkdir(parents=True, exist_ok=True)
    cv2.imwrite(name, colored)
2023-12-25 22:53:53 +08:00
def DumpTensorToLog(tensor, name="log"):
    """Write every element of ``tensor`` to the text file ``name``, one value per line.

    The tensor is flattened with ``reshape([-1])`` and cast to float before
    writing, so integer tensors appear as e.g. ``3.0``. An existing file is
    overwritten. The file ends with a newline.
    """
    values = tensor.reshape([-1]).float().cpu().detach().numpy().tolist()
    # Text mode already translates "\n" to the platform line ending. Writing
    # os.linesep here would produce "\r\r\n" on Windows.
    with open(name, "w", encoding="utf-8") as f:
        f.writelines(f"{v}\n" for v in values)
2023-12-26 14:08:02 +08:00
2023-12-25 22:53:53 +08:00
def DumpTensorToFile(tensor, name="tensor.pt"):
    """Serialize ``tensor`` to ``name`` with ``torch.save``, after copying it to the CPU."""
    host_tensor = tensor.cpu()
    torch.save(host_tensor, name)
2023-12-25 22:53:53 +08:00
def LoadTensorToFile(name="tensor.pt"):
    """Load and return the object stored in ``name`` using ``torch.load``.

    NOTE(review): ``torch.load`` unpickles the file. Only use it on trusted
    files, such as ones produced by ``DumpTensorToFile``.
    """
    loaded = torch.load(name)
    return loaded
def DumpListToFile(list, name="list"):
    """Write each item of ``list`` to the UTF-8 text file ``name``, one ``str(item)`` per line.

    An existing file is overwritten. The file ends with a newline unless the
    input is empty.
    """
    # NOTE: the parameter shadows the builtin ``list``. The name is kept for
    # callers that pass it by keyword.
    # str(item) instead of "%s" % item: the latter raises for tuple items.
    # "\n" instead of os.linesep: text mode already converts it, so os.linesep
    # would produce "\r\r\n" on Windows.
    with open(name, "w", encoding="utf-8") as f:
        f.writelines(f"{item}\n" for item in list)
2023-12-29 19:55:53 +08:00
# Running totals that ProbGE0 updates and DumpProb reports.
prob_true = 0  # number of elements seen so far that were >= 0
prob_all = 0  # total number of elements seen so far


def ProbGE0(tensor: torch.Tensor):
    """Add the count of elements of ``tensor`` that are >= 0 to the module totals."""
    global prob_true
    global prob_all
    prob_true += tensor.ge(0).sum().item()
    prob_all += tensor.numel()


def DumpProb():
    """Print the accumulated counts and the percentage of elements that were >= 0."""
    print("prob_true : " + str(prob_true))
    print("prob_all : " + str(prob_all))
    if prob_all == 0:
        # Nothing has been accumulated yet; avoid a ZeroDivisionError.
        print("prob : n/a")
        return
    print("prob : " + str((prob_true * 100) / prob_all) + "%")