From 539392c843cd6ef4f5344d7a65f06130c1f96aff Mon Sep 17 00:00:00 2001 From: Colin Date: Thu, 21 Dec 2023 21:20:49 +0800 Subject: [PATCH] Add auto2d. --- test.png | Bin 0 -> 217 bytes tools/show.py | 19 ++++++++++++++----- tools/test.py | 8 ++++++-- 3 files changed, 20 insertions(+), 7 deletions(-) create mode 100644 test.png diff --git a/test.png b/test.png new file mode 100644 index 0000000000000000000000000000000000000000..ba037b858abdf048bd2091cf26993e9288a2cc72 GIT binary patch literal 217 zcmeAS@N?(olHy`uVBq!ia0vp^JRr;gBp8b2n5}`-0#6sm5DwYo9{ZF-|G!NvIkPOD z&0ga6^eVZBypcC9Jvmj>c(LT!)7jCtx8D!iXO?aIxQW?*%8amU>%+tz1f|U}HU8Ou zX0~%s(3;7&B-HoUeydNOv1`MntxvUjsumrpX4bRfHjMf4{>+Qzx?hWyPkQw|^r%>{ z#Pcxe>cV{~v!oCFJ|xGTgLzZ literal 0 HcmV?d00001 diff --git a/tools/show.py b/tools/show.py index 3117c82..3547252 100644 --- a/tools/show.py +++ b/tools/show.py @@ -3,21 +3,30 @@ import torch import torch.nn.functional as F import torchvision.transforms.functional as Vision import cv2 +import math +import numpy as np -def DumpTensorToImage(tensor, name, autoPad=True, scale=1.0): - if len(tensor.shape) != 2: +def DumpTensorToImage(tensor, name, autoPad=True, scale=1.0, auto2d=True): + if len(tensor.shape) != 2 and len(tensor.shape) != 1: raise ("Error input dims") + tensor = tensor.float() maxv = torch.max(tensor) minv = torch.min(tensor) tensor = (((tensor - minv) / (maxv - minv)) * 256).byte().cpu() img = tensor.numpy() srp = img.shape + + if auto2d and len(srp) == 1: + ceiled = math.ceil((srp[0]) ** 0.5) + img = cv2.copyMakeBorder(img, 0, ceiled * ceiled - srp[0], 0, 0, 0) + img = img.reshape((ceiled, ceiled)) + srp = img.shape if autoPad and (max(srp) / min(srp) > 16): - img = cv2.resize(img,[max(srp),max(srp)]) - srp = img.shape + img = cv2.resize(img, [max(srp), max(srp)]) + srp = img.shape if scale != 1.0: img = cv2.resize(img, [int(srp[0] * scale), int(srp[1] * scale)]) - srp = img.shape + srp = img.shape cv2.imwrite(name, img) diff --git a/tools/test.py b/tools/test.py index ef65f14..858d675 100644 --- a/tools/test.py +++ b/tools/test.py @@ -2,5 +2,9 @@ import show import torch -radata = torch.randn(8192, 128) -show.DumpTensorToImage(radata, "test.png", autoPad=True,scale=0.2) +# radata = torch.randn(8192, 128) +# show.DumpTensorToImage(radata, "test.png", autoPad=True,scale=0.2) + + +radata = torch.randn(127) +show.DumpTensorToImage(radata, "test.png")