diff --git a/tools/show.py b/tools/show.py index 4fc731f..3eb4183 100644 --- a/tools/show.py +++ b/tools/show.py @@ -8,10 +8,22 @@ import numpy as np import os -def DumpTensorToImage(tensor, name, autoPad=True, scale=1.0, auto2d=True): - if len(tensor.shape) != 2 and len(tensor.shape) != 1: +def DumpTensorToImage(tensor, name, forceSquare=True, scale=1.0): + if len(tensor.shape) != 2 and len(tensor.shape) != 1 and len(tensor.shape) != 3: raise ("Error input dims") + if len(tensor.shape) == 3: + channel = tensor.shape[0] + x = math.ceil((channel) ** 0.5) + tensor = F.pad( + tensor, (0, 1, 0, 1, 0, x * x - channel), mode="constant", value=0 + ) + tensor = tensor.reshape((x, x, tensor.shape[1], tensor.shape[2])) + tensor = tensor.permute((0, 2, 1, 3)) + tensor = tensor.reshape((x * tensor.shape[1], x * tensor.shape[3])) + DumpTensorToImage(tensor, name, forceSquare=False, scale=scale) + return + tensor = tensor.float() maxv = torch.max(tensor) minv = torch.min(tensor) @@ -19,12 +31,12 @@ def DumpTensorToImage(tensor, name, autoPad=True, scale=1.0, auto2d=True): img = tensor.numpy() srp = img.shape - if auto2d and len(srp) == 1: # 1D的数据自动折叠成2D图像 + if len(srp) == 1: # 1D的数据自动折叠成2D图像 ceiled = math.ceil((srp[0]) ** 0.5) img = cv2.copyMakeBorder(img, 0, ceiled * ceiled - srp[0], 0, 0, 0) img = img.reshape((ceiled, ceiled)) srp = img.shape - if autoPad and (max(srp) / min(srp) > 16): # 自动拉伸成正方形 + if forceSquare: # 拉伸成正方形 img = cv2.resize(img, [max(srp), max(srp)]) srp = img.shape if scale != 1.0: diff --git a/tools/test.py b/tools/test.py index 0111b7b..a97d099 100644 --- a/tools/test.py +++ b/tools/test.py @@ -9,6 +9,12 @@ import torch radata = torch.randn(127) show.DumpTensorToImage(radata, "test.png") + +radata = torch.randn(3,127,127) +show.DumpTensorToImage(radata, "test.png") + + + radata = torch.randn(127, 127) show.DumpTensorToLog(radata, "test.log")