Update tools of show.
This commit is contained in:
parent
063f722177
commit
9386d044b6
|
@ -8,10 +8,22 @@ import numpy as np
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
|
||||||
def DumpTensorToImage(tensor, name, autoPad=True, scale=1.0, auto2d=True):
|
def DumpTensorToImage(tensor, name, forceSquare=True, scale=1.0):
|
||||||
if len(tensor.shape) != 2 and len(tensor.shape) != 1:
|
if len(tensor.shape) != 2 and len(tensor.shape) != 1 and len(tensor.shape) != 3:
|
||||||
raise ("Error input dims")
|
raise ("Error input dims")
|
||||||
|
|
||||||
|
if len(tensor.shape) == 3:
|
||||||
|
channel = tensor.shape[0]
|
||||||
|
x = math.ceil((channel) ** 0.5)
|
||||||
|
tensor = F.pad(
|
||||||
|
tensor, (0, 1, 0, 1, 0, x * x - channel), mode="constant", value=0
|
||||||
|
)
|
||||||
|
tensor = tensor.reshape((x, x, tensor.shape[1], tensor.shape[2]))
|
||||||
|
tensor = tensor.permute((0, 2, 1, 3))
|
||||||
|
tensor = tensor.reshape((x * tensor.shape[1], x * tensor.shape[3]))
|
||||||
|
DumpTensorToImage(tensor, name, forceSquare=False, scale=scale)
|
||||||
|
return
|
||||||
|
|
||||||
tensor = tensor.float()
|
tensor = tensor.float()
|
||||||
maxv = torch.max(tensor)
|
maxv = torch.max(tensor)
|
||||||
minv = torch.min(tensor)
|
minv = torch.min(tensor)
|
||||||
|
@ -19,12 +31,12 @@ def DumpTensorToImage(tensor, name, autoPad=True, scale=1.0, auto2d=True):
|
||||||
img = tensor.numpy()
|
img = tensor.numpy()
|
||||||
srp = img.shape
|
srp = img.shape
|
||||||
|
|
||||||
if auto2d and len(srp) == 1: # 1D的数据自动折叠成2D图像
|
if len(srp) == 1: # 1D的数据自动折叠成2D图像
|
||||||
ceiled = math.ceil((srp[0]) ** 0.5)
|
ceiled = math.ceil((srp[0]) ** 0.5)
|
||||||
img = cv2.copyMakeBorder(img, 0, ceiled * ceiled - srp[0], 0, 0, 0)
|
img = cv2.copyMakeBorder(img, 0, ceiled * ceiled - srp[0], 0, 0, 0)
|
||||||
img = img.reshape((ceiled, ceiled))
|
img = img.reshape((ceiled, ceiled))
|
||||||
srp = img.shape
|
srp = img.shape
|
||||||
if autoPad and (max(srp) / min(srp) > 16): # 自动拉伸成正方形
|
if forceSquare: # 拉伸成正方形
|
||||||
img = cv2.resize(img, [max(srp), max(srp)])
|
img = cv2.resize(img, [max(srp), max(srp)])
|
||||||
srp = img.shape
|
srp = img.shape
|
||||||
if scale != 1.0:
|
if scale != 1.0:
|
||||||
|
|
|
@ -9,6 +9,12 @@ import torch
|
||||||
radata = torch.randn(127)
|
radata = torch.randn(127)
|
||||||
show.DumpTensorToImage(radata, "test.png")
|
show.DumpTensorToImage(radata, "test.png")
|
||||||
|
|
||||||
|
|
||||||
|
radata = torch.randn(3,127,127)
|
||||||
|
show.DumpTensorToImage(radata, "test.png")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
radata = torch.randn(127, 127)
|
radata = torch.randn(127, 127)
|
||||||
show.DumpTensorToLog(radata, "test.log")
|
show.DumpTensorToLog(radata, "test.log")
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue