diff --git a/test.png b/test.png new file mode 100644 index 0000000..ba037b8 Binary files /dev/null and b/test.png differ diff --git a/tools/show.py b/tools/show.py index 3117c82..3547252 100644 --- a/tools/show.py +++ b/tools/show.py @@ -3,21 +3,30 @@ import torch import torch.nn.functional as F import torchvision.transforms.functional as Vision import cv2 +import math +import numpy as np -def DumpTensorToImage(tensor, name, autoPad=True, scale=1.0): - if len(tensor.shape) != 2: +def DumpTensorToImage(tensor, name, autoPad=True, scale=1.0, auto2d=True): + if len(tensor.shape) != 2 and len(tensor.shape) != 1: raise ("Error input dims") + tensor = tensor.float() maxv = torch.max(tensor) minv = torch.min(tensor) tensor = (((tensor - minv) / (maxv - minv)) * 256).byte().cpu() img = tensor.numpy() srp = img.shape + + if auto2d and len(srp) == 1: + ceiled = math.ceil((srp[0]) ** 0.5) + img = cv2.copyMakeBorder(img, 0, ceiled * ceiled - srp[0], 0, 0, 0) + img = img.reshape((ceiled, ceiled)) + srp = img.shape if autoPad and (max(srp) / min(srp) > 16): - img = cv2.resize(img,[max(srp),max(srp)]) - srp = img.shape + img = cv2.resize(img, [max(srp), max(srp)]) + srp = img.shape if scale != 1.0: img = cv2.resize(img, [int(srp[0] * scale), int(srp[1] * scale)]) - srp = img.shape + srp = img.shape cv2.imwrite(name, img) diff --git a/tools/test.py b/tools/test.py index ef65f14..858d675 100644 --- a/tools/test.py +++ b/tools/test.py @@ -2,5 +2,9 @@ import show import torch -radata = torch.randn(8192, 128) -show.DumpTensorToImage(radata, "test.png", autoPad=True,scale=0.2) +# radata = torch.randn(8192, 128) +# show.DumpTensorToImage(radata, "test.png", autoPad=True,scale=0.2) + + +radata = torch.randn(127) +show.DumpTensorToImage(radata, "test.png")