96 lines
		
	
	
		
			2.8 KiB
		
	
	
	
		
			Python
		
	
	
	
			
		
		
	
	
			96 lines
		
	
	
		
			2.8 KiB
		
	
	
	
		
			Python
		
	
	
	
import plotly_express as px
 | 
						|
import torch
 | 
						|
import torch.nn.functional as F
 | 
						|
import torchvision.transforms.functional as Vision
 | 
						|
import cv2
 | 
						|
import math
 | 
						|
import numpy as np
 | 
						|
import os
 | 
						|
 | 
						|
 | 
						|
def _stretch_to_255(values, dim=None):
    """Linearly rescale ``values`` onto 0..255 as float32.

    With ``dim=None`` a single global min/max is used; otherwise each slice along
    ``dim`` is rescaled independently. A slice whose max equals its min (e.g. an
    all-zero filler map) maps to 0 instead of producing NaN from 0/0.
    """
    values = values.float()
    if dim is None:
        lo, hi = values.min(), values.max()
    else:
        lo = values.amin(dim=dim, keepdim=True)
        hi = values.amax(dim=dim, keepdim=True)
    span = hi - lo
    # Guard against division by zero for constant data.
    span = torch.where(span > 0, span, torch.ones_like(span))
    return (values - lo) / span * 255


def _tile_channels(tensor, auto_contrast):
    """Tile a (C, H, W) tensor into one 2-D mosaic of ceil(sqrt(C))^2 cells.

    Each map gets one zero pixel of separator after its last row and column.
    All-zero filler maps are appended so the grid is square.
    """
    channels = tensor.shape[0]
    grid = math.ceil(channels ** 0.5)
    tensor = F.pad(tensor, (0, 1, 0, 1, 0, grid * grid - channels), mode="constant", value=0)
    height, width = tensor.shape[1], tensor.shape[2]
    flat = tensor.reshape((grid * grid, height * width))
    if auto_contrast:
        flat = _stretch_to_255(flat, dim=1)  # per-map contrast
    tiles = flat.reshape((grid, grid, height, width))
    # (grid_row, grid_col, h, w) -> (grid_row, h, grid_col, w) so each output row
    # walks across one pixel row of every map in a grid row.
    return tiles.permute((0, 2, 1, 3)).reshape((grid * height, grid * width))


def DumpTensorToImage(tensor, name, forceSquare=True, scale=1.0, AutoContrast=True):
    """Write a 1-, 2- or 3-D tensor to an image file as 8-bit grayscale.

    * 1-D input of length N is zero-padded to the next perfect square and folded
      row-major into a square image.
    * 2-D input is written as rows x cols.
    * 3-D input is treated as C stacked 2-D maps (dim 0 indexes the map).
      The maps are tiled into a square mosaic with one-pixel zero separators,
      and the mosaic is never forced square.

    Args:
        tensor: 1-, 2- or 3-D torch tensor (any device/dtype).
        name: output path; ``cv2.imwrite`` chooses the format from its extension.
        forceSquare: for 1-/2-D input, stretch the image to max(h, w) on both sides.
        scale: additional resize factor applied to both image dimensions.
        AutoContrast: linearly map the value range onto 0..255 (per map for 3-D
            input, globally otherwise). When False the values are cast to uint8
            as-is, so they should already lie in 0..255.

    Raises:
        ValueError: if ``tensor`` is not 1-, 2- or 3-D.
    """
    if tensor.dim() not in (1, 2, 3):
        raise ValueError("Error input dims")

    # Detach so tensors that require grad can be converted to NumPy below.
    tensor = tensor.detach()

    if tensor.dim() == 3:
        mosaic = _tile_channels(tensor, AutoContrast)
        DumpTensorToImage(mosaic, name, forceSquare=False, scale=scale, AutoContrast=False)
        return

    tensor = tensor.float()
    if AutoContrast:
        tensor = _stretch_to_255(tensor)
    img = tensor.byte().cpu().numpy()

    if img.ndim == 1:  # fold 1-D data into a square 2-D image
        length = img.shape[0]
        side = math.isqrt(length)
        if side * side < length:
            side += 1
        img = np.pad(img, (0, side * side - length)).reshape((side, side))

    if forceSquare:  # stretch into a square
        side = max(img.shape)
        img = cv2.resize(img, (side, side))

    if scale != 1.0:
        height, width = img.shape
        # cv2.resize takes dsize as (width, height) -- the reverse of NumPy's shape.
        img = cv2.resize(img, (int(width * scale), int(height * scale)))

    cv2.imwrite(name, img)
 | 
						|
 | 
						|
 | 
						|
def DumpTensorToLog(tensor, name="log"):
    """Write every element of ``tensor`` to the text file ``name``, one per line.

    The tensor is flattened in row-major order and converted to float32. Each line
    holds ``str()`` of the resulting Python float (e.g. ``1.0``). An existing file
    is overwritten.
    """
    values = tensor.detach().reshape(-1).float().cpu().tolist()
    # Text mode already translates "\n" to the platform line ending; writing
    # os.linesep here would produce "\r\r\n" on Windows.
    with open(name, "w") as f:
        f.writelines(str(v) + "\n" for v in values)
 | 
						|
 | 
						|
 | 
						|
def DumpTensorToFile(tensor, name="tensor.pt"):
    """Serialize ``tensor`` to the path ``name`` with ``torch.save``.

    The tensor is first moved to CPU memory, so the file can be loaded on a
    machine without the original device.
    """
    host_copy = tensor.to("cpu")
    torch.save(host_copy, name)
 | 
						|
 | 
						|
 | 
						|
def LoadTensorToFile(name="tensor.pt"):
    """Load and return the object stored at ``name`` via ``torch.load``.

    NOTE(review): ``torch.load`` is pickle-based. Only load files from trusted
    sources. Whether unrestricted unpickling is allowed depends on the installed
    PyTorch version's ``weights_only`` default.
    """
    restored = torch.load(name)
    return restored
 | 
						|
 | 
						|
 | 
						|
def DumpListToFile(list, name="list"):
    """Write each item of ``list`` to the text file ``name`` as ``str(item)``, one per line.

    ``list`` may be any iterable. The parameter name shadows the builtin but is kept
    so callers passing it by keyword keep working. An existing file is overwritten.
    """
    # str(item) instead of "%s" % item, which breaks on tuple items.
    # "\n" instead of os.linesep, which doubles "\r" in text mode on Windows.
    with open(name, "w") as f:
        f.writelines(str(item) + "\n" for item in list)
 | 
						|
 | 
						|
 | 
						|
# Running totals accumulated by ProbGE0 and reported by DumpProb.
prob_true = 0  # number of elements seen so far that were >= 0
prob_all = 0  # total number of elements seen so far


def ProbGE0(tensor: torch.Tensor):
    """Accumulate how many elements of ``tensor`` are >= 0 into the module counters.

    Adds the count of non-negative elements to ``prob_true`` and the total number
    of elements to ``prob_all``. NaN elements count toward ``prob_all`` only.
    """
    global prob_true, prob_all
    prob_true += tensor.ge(0).sum().item()
    prob_all += tensor.numel()


def DumpProb():
    """Print the counters gathered by ``ProbGE0`` and the percentage of elements >= 0."""
    print("prob_true : " + str(prob_true))
    print("prob_all : " + str(prob_all))
    if prob_all == 0:
        # Nothing recorded yet; avoid ZeroDivisionError.
        print("prob : n/a (no elements recorded)")
        return
    print("prob : " + str((prob_true * 100) / prob_all) + "%")
 |