Update show tools.

This commit is contained in:
Colin 2024-10-27 19:48:53 +08:00
parent a0da2565fe
commit 6a0b47c674
1 changed files with 7 additions and 1 deletions

View File

@ -9,11 +9,13 @@ import os
from pathlib import Path from pathlib import Path
def DumpTensorToImage(tensor, name, forceSquare=False, scale=1.0, Contrast=None, GridValue=None): def DumpTensorToImage(tensor, name, forceSquare=False, scale=1.0, Contrast=None, GridValue=None, Value2Log=False):
if len(tensor.shape) != 2 and len(tensor.shape) != 1 and len(tensor.shape) != 3: if len(tensor.shape) != 2 and len(tensor.shape) != 1 and len(tensor.shape) != 3:
raise ("Error input dims") raise ("Error input dims")
if ("." not in name) or (name.split(".")[-1] not in {"jpg", "png", "bmp"}): if ("." not in name) or (name.split(".")[-1] not in {"jpg", "png", "bmp"}):
raise ("Error input name") raise ("Error input name")
if Value2Log:
DumpTensorToLog(tensor, name[:-4] + ".log")
if len(tensor.shape) == 3: if len(tensor.shape) == 3:
channel = tensor.shape[0] channel = tensor.shape[0]
@ -77,7 +79,11 @@ def DumpTensorToImage(tensor, name, forceSquare=False, scale=1.0, Contrast=None,
def DumpTensorToLog(tensor, name="log"): def DumpTensorToLog(tensor, name="log"):
shape = tensor.shape shape = tensor.shape
tensor_mean = torch.mean(tensor).cpu().detach().numpy()
tensor_range = (torch.max(tensor) - torch.min(tensor)).cpu().detach().numpy()
f = open(name, "w") f = open(name, "w")
f.writelines("tensor mean: %s" % tensor_mean + os.linesep)
f.writelines("tensor range: %s" % tensor_range + os.linesep)
data = tensor.reshape([-1]).float().cpu().detach().numpy().tolist() data = tensor.reshape([-1]).float().cpu().detach().numpy().tolist()
for d in data: for d in data:
f.writelines("%s" % d + os.linesep) f.writelines("%s" % d + os.linesep)