Update tools of show.
This commit is contained in:
		
							parent
							
								
									063f722177
								
							
						
					
					
						commit
						9386d044b6
					
				| 
						 | 
				
			
			@ -8,10 +8,22 @@ import numpy as np
 | 
			
		|||
import os
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def DumpTensorToImage(tensor, name, autoPad=True, scale=1.0, auto2d=True):
 | 
			
		||||
    if len(tensor.shape) != 2 and len(tensor.shape) != 1:
 | 
			
		||||
def DumpTensorToImage(tensor, name, forceSquare=True, scale=1.0):
 | 
			
		||||
    if len(tensor.shape) != 2 and len(tensor.shape) != 1 and len(tensor.shape) != 3:
 | 
			
		||||
        raise ("Error input dims")
 | 
			
		||||
 | 
			
		||||
    if len(tensor.shape) == 3:
 | 
			
		||||
        channel = tensor.shape[0]
 | 
			
		||||
        x = math.ceil((channel) ** 0.5)
 | 
			
		||||
        tensor = F.pad(
 | 
			
		||||
            tensor, (0, 1, 0, 1, 0, x * x - channel), mode="constant", value=0
 | 
			
		||||
        )
 | 
			
		||||
        tensor = tensor.reshape((x, x, tensor.shape[1], tensor.shape[2]))
 | 
			
		||||
        tensor = tensor.permute((0, 2, 1, 3))
 | 
			
		||||
        tensor = tensor.reshape((x * tensor.shape[1], x * tensor.shape[3]))
 | 
			
		||||
        DumpTensorToImage(tensor, name, forceSquare=False, scale=scale)
 | 
			
		||||
        return
 | 
			
		||||
 | 
			
		||||
    tensor = tensor.float()
 | 
			
		||||
    maxv = torch.max(tensor)
 | 
			
		||||
    minv = torch.min(tensor)
 | 
			
		||||
| 
						 | 
				
			
			@ -19,12 +31,12 @@ def DumpTensorToImage(tensor, name, autoPad=True, scale=1.0, auto2d=True):
 | 
			
		|||
    img = tensor.numpy()
 | 
			
		||||
    srp = img.shape
 | 
			
		||||
 | 
			
		||||
    if auto2d and len(srp) == 1:  # 1D的数据自动折叠成2D图像
 | 
			
		||||
    if len(srp) == 1:  # 1D的数据自动折叠成2D图像
 | 
			
		||||
        ceiled = math.ceil((srp[0]) ** 0.5)
 | 
			
		||||
        img = cv2.copyMakeBorder(img, 0, ceiled * ceiled - srp[0], 0, 0, 0)
 | 
			
		||||
        img = img.reshape((ceiled, ceiled))
 | 
			
		||||
        srp = img.shape
 | 
			
		||||
    if autoPad and (max(srp) / min(srp) > 16):  # 自动拉伸成正方形
 | 
			
		||||
    if forceSquare:  # 拉伸成正方形
 | 
			
		||||
        img = cv2.resize(img, [max(srp), max(srp)])
 | 
			
		||||
        srp = img.shape
 | 
			
		||||
    if scale != 1.0:
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -9,6 +9,12 @@ import torch
 | 
			
		|||
radata = torch.randn(127)
 | 
			
		||||
show.DumpTensorToImage(radata, "test.png")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
radata = torch.randn(3,127,127)
 | 
			
		||||
show.DumpTensorToImage(radata, "test.png")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
radata = torch.randn(127, 127)
 | 
			
		||||
show.DumpTensorToLog(radata, "test.log")
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue