58 lines
		
	
	
		
			1.6 KiB
		
	
	
	
		
			Python
		
	
	
	
			
		
		
	
	
			58 lines
		
	
	
		
			1.6 KiB
		
	
	
	
		
			Python
		
	
	
	
| import plotly_express as px
 | |
| import torch
 | |
| import torch.nn.functional as F
 | |
| import torchvision.transforms.functional as Vision
 | |
| import cv2
 | |
| import math
 | |
| import numpy as np
 | |
| import os
 | |
| 
 | |
| 
 | |
def DumpTensorToImage(tensor, name, autoPad=True, scale=1.0, auto2d=True):
    """Write a 1-D or 2-D tensor to ``name`` as an 8-bit grayscale image.

    The values are cast to float and min-max normalized to 0..255 before
    being handed to ``cv2.imwrite``.

    Args:
        tensor: Torch tensor with exactly one or two dimensions.
        name: Output path passed to ``cv2.imwrite``.
        autoPad: When true, an image whose longer side is more than 16 times
            its shorter side is resized to a square of the longer side.
        scale: Factor applied to both image sides; ``1.0`` skips resizing.
        auto2d: When true, a 1-D tensor is zero-padded at the end and folded
            into the smallest square image that holds all of its elements.

    Raises:
        ValueError: If ``tensor`` is neither 1-D nor 2-D.
    """
    if len(tensor.shape) not in (1, 2):
        # The original raised a plain string, which itself fails with a
        # TypeError; raise a real exception carrying the same message.
        raise ValueError("Error input dims")

    tensor = tensor.float()
    minv = torch.min(tensor)
    value_range = torch.max(tensor) - minv
    if value_range > 0:
        normalized = (tensor - minv) / value_range
    else:
        # Constant input: dividing by a zero range would yield NaN, so emit
        # an all-zero image instead.
        normalized = torch.zeros_like(tensor)
    img = (normalized * 255).byte().cpu().numpy()
    srp = img.shape

    if auto2d and len(srp) == 1:  # fold 1-D data into a square 2-D image
        ceiled = math.ceil(srp[0] ** 0.5)
        # Zero-fill the tail so the element count is a perfect square.
        img = np.pad(img, (0, ceiled * ceiled - srp[0]), mode="constant")
        img = img.reshape((ceiled, ceiled))
        srp = img.shape
    if autoPad and (max(srp) / min(srp) > 16):  # stretch into a square
        longest = max(srp)
        img = cv2.resize(img, (longest, longest))
        srp = img.shape
    if scale != 1.0:
        # cv2.resize takes dsize as (width, height) while a numpy shape is
        # (rows, cols), so the two sides must be swapped here.
        img = cv2.resize(img, (int(srp[1] * scale), int(srp[0] * scale)))
    cv2.imwrite(name, img)
 | |
| 
 | |
| 
 | |
def DumpTensorToLog(tensor, name="log"):
    """Write every element of ``tensor`` to a text file, one value per line.

    The tensor is flattened, cast to float and moved to the CPU before its
    values are formatted with ``%s``.

    Args:
        tensor: Torch tensor of any shape.
        name: Path of the text file to create or overwrite.
    """
    values = tensor.reshape([-1]).float().cpu().numpy().tolist()
    # "\n" is used instead of os.linesep because text mode already translates
    # newlines; writing os.linesep would produce "\r\r\n" on Windows.
    # The context manager closes the file even if a write fails.
    with open(name, "w") as f:
        f.writelines("%s\n" % value for value in values)
 | |
| 
 | |
| 
 | |
def DumpTensorToFile(tensor, name="tensor.pt"):
    """Save ``tensor`` to ``name`` with ``torch.save`` after calling ``.cpu()``.

    Args:
        tensor: Torch tensor to serialize.
        name: Destination path handed to ``torch.save``.
    """
    host_tensor = tensor.cpu()
    torch.save(host_tensor, name)
 | |
| 
 | |
| 
 | |
def LoadTensorToFile(name="tensor.pt"):
    """Return whatever ``torch.load`` reads from the file at ``name``.

    Despite the function name, this reads from the file; nothing is written.

    Args:
        name: Path of the file to load.

    Returns:
        The object produced by ``torch.load(name)``.
    """
    # NOTE(review): torch.load is documented as pickle-based by default, so
    # this should only be pointed at trusted files - confirm against callers.
    loaded = torch.load(name)
    return loaded
 | |
| 
 | |
| 
 | |
def DumpListToFile(list, name="list"):
    """Write each item of ``list`` to a text file, one item per line.

    Items are formatted with ``%s``.

    Args:
        list: Iterable of items to write. The parameter name shadows the
            builtin but is kept so existing keyword callers keep working.
        name: Path of the text file to create or overwrite.
    """
    # "\n" is used instead of os.linesep because text mode already translates
    # newlines; writing os.linesep would produce "\r\r\n" on Windows.
    # The context manager closes the file even if formatting an item fails.
    with open(name, "w") as f:
        for item in list:
            f.write("%s" % item + "\n")
 |