2023-12-21 19:52:19 +08:00
|
|
|
import plotly_express as px
|
|
|
|
import torch
|
|
|
|
import torch.nn.functional as F
|
|
|
|
import torchvision.transforms.functional as Vision
|
|
|
|
import cv2
|
2023-12-21 21:20:49 +08:00
|
|
|
import math
|
|
|
|
import numpy as np
|
2023-12-25 22:53:53 +08:00
|
|
|
import os
|
2023-12-21 19:52:19 +08:00
|
|
|
|
|
|
|
|
2023-12-21 21:20:49 +08:00
|
|
|
def DumpTensorToImage(tensor, name, autoPad=True, scale=1.0, auto2d=True):
    """Min-max normalize a 1-D or 2-D tensor to 8-bit grayscale and write it as an image.

    Args:
        tensor: 1-D or 2-D torch tensor; it is detached and converted to float.
        name: Output path passed to ``cv2.imwrite`` (the extension picks the format).
        autoPad: If True and the longer side of a 2-D image is more than 16x the
            shorter side, stretch the image to a ``max(h, w) x max(h, w)`` square.
        scale: Resize factor applied to both axes after the optional stretch.
        auto2d: If True, a 1-D tensor is zero-padded to the next perfect-square
            length and folded row by row into a square 2-D image.

    Returns:
        The boolean returned by ``cv2.imwrite`` (False if the file was not written).

    Raises:
        ValueError: If ``tensor`` is not 1-D or 2-D.
    """
    if tensor.dim() not in (1, 2):
        # Raise a real exception type; raising a plain string is itself a TypeError.
        raise ValueError("Error input dims")

    values = tensor.detach().float()
    minv = torch.min(values)
    maxv = torch.max(values)
    span = maxv - minv
    if span > 0:
        values = (values - minv) / span
    else:
        # Constant (or NaN-containing) input: avoid 0/0, whose uint8 cast is
        # undefined, and emit an all-black image instead.
        values = torch.zeros_like(values)
    # contiguous() so cv2 receives a C-ordered buffer even for transposed views.
    img = (values * 255).byte().contiguous().cpu().numpy()

    if auto2d and img.ndim == 1:  # fold 1-D data into a square 2-D image
        length = img.shape[0]
        side = math.isqrt(length)  # exact integer sqrt, no float rounding
        if side * side < length:
            side += 1
        # Zero-pad the tail so the data fills a side x side square row by row.
        img = np.pad(img, (0, side * side - length)).reshape(side, side)

    height = img.shape[0]
    # cv2 treats a 1-D array as a single column, so its width is 1.
    width = img.shape[1] if img.ndim > 1 else 1

    if autoPad and img.ndim == 2 and max(height, width) / min(height, width) > 16:
        # Very elongated images are hard to inspect; stretch to a square.
        side = max(height, width)
        img = cv2.resize(img, (side, side))
        height = width = side

    if scale != 1.0:
        # cv2.resize takes dsize as (width, height); keep each side at least 1 px.
        new_w = max(1, int(width * scale))
        new_h = max(1, int(height * scale))
        img = cv2.resize(img, (new_w, new_h))

    return cv2.imwrite(name, img)
|
2023-12-25 22:53:53 +08:00
|
|
|
|
|
|
|
|
|
|
|
def DumpTensorToLog(tensor, name="log"):
    """Write every element of ``tensor`` as text, one value per line, to file ``name``.

    The tensor is detached, flattened in row-major order and converted to float,
    so integer/bool inputs are written as e.g. ``1.0``. Any existing file at
    ``name`` is overwritten.
    """
    # detach() so tensors that require grad can be converted to numpy.
    values = tensor.detach().reshape(-1).float().cpu().numpy().tolist()
    # Text mode already translates "\n" to the platform newline; writing
    # os.linesep here would produce "\r\r\n" on Windows.
    with open(name, "w") as f:
        f.writelines(f"{v}\n" for v in values)
|
|
|
|
|
2023-12-26 14:08:02 +08:00
|
|
|
|
2023-12-25 22:53:53 +08:00
|
|
|
def DumpTensorToFile(tensor, name="tensor.pt"):
    """Serialize ``tensor`` to the path ``name`` using ``torch.save``.

    The tensor is copied to the CPU first, so reading the file back does not
    require the device the tensor originally lived on.
    """
    host_tensor = tensor.to("cpu")
    torch.save(host_tensor, name)
|
|
|
|
|
2023-12-25 22:53:53 +08:00
|
|
|
|
|
|
|
def LoadTensorToFile(name="tensor.pt"):
    """Load and return the object that ``torch.save`` stored at path ``name``.

    NOTE(review): ``torch.load`` is pickle-based; only load files from trusted
    sources (recent PyTorch releases default to ``weights_only=True``).
    """
    restored = torch.load(name)
    return restored
|
|
|
|
|
|
|
|
|
|
|
|
def DumpListToFile(list, name="list"):
    """Write ``str(item)`` for each item of ``list`` on its own line in file ``name``.

    ``list`` may be any iterable; the parameter name shadows the builtin but is
    kept so keyword callers keep working. Any existing file is overwritten.
    """
    # f-string formatting (unlike "%s" % item) also handles tuple items, and
    # "\n" in text mode becomes the platform newline without doubling "\r".
    with open(name, "w") as f:
        f.writelines(f"{item}\n" for item in list)
|
2023-12-29 19:55:53 +08:00
|
|
|
|
|
|
|
|
|
|
|
# Running counters updated by ProbGE0 and printed by DumpProb:
# prob_true counts elements seen that were >= 0, prob_all counts all elements seen.
prob_true = prob_all = 0
|
|
|
|
|
|
|
|
|
|
|
|
def ProbGE0(tensor: torch.Tensor):
    """Add ``tensor``'s statistics to the module-level counters.

    Adds the number of elements satisfying ``tensor >= 0`` to ``prob_true`` and
    the total element count to ``prob_all``. NaN elements count toward
    ``prob_all`` but never toward ``prob_true``.
    """
    global prob_true, prob_all
    prob_true += int(tensor.ge(0).sum().item())
    prob_all += tensor.numel()  # same as the product of the shape (1 for 0-d)
|
|
|
|
|
|
|
|
|
|
|
|
def DumpProb():
    """Print the ``prob_true`` / ``prob_all`` counters and their ratio as a percentage.

    Prints ``prob : n/a`` instead of dividing by zero when nothing has been
    counted yet (``prob_all == 0``).
    """
    print("prob_true : " + str(prob_true))
    print("prob_all : " + str(prob_all))
    if prob_all == 0:
        print("prob : n/a")
    else:
        print("prob : " + str((prob_true * 100) / prob_all) + "%")
|