diff --git a/tools/show.py b/tools/show.py index e7b3370..a9d2887 100644 --- a/tools/show.py +++ b/tools/show.py @@ -9,11 +9,13 @@ import os from pathlib import Path -def DumpTensorToImage(tensor, name, forceSquare=False, scale=1.0, Contrast=None, GridValue=None): +def DumpTensorToImage(tensor, name, forceSquare=False, scale=1.0, Contrast=None, GridValue=None, Value2Log=False): if len(tensor.shape) != 2 and len(tensor.shape) != 1 and len(tensor.shape) != 3: raise ("Error input dims") if ("." not in name) or (name.split(".")[-1] not in {"jpg", "png", "bmp"}): raise ("Error input name") + if Value2Log: + DumpTensorToLog(tensor, name[:-4] + ".log") if len(tensor.shape) == 3: channel = tensor.shape[0] @@ -77,7 +79,11 @@ def DumpTensorToImage(tensor, name, forceSquare=False, scale=1.0, Contrast=None, def DumpTensorToLog(tensor, name="log"): shape = tensor.shape + tensor_mean = torch.mean(tensor).cpu().detach().numpy() + tensor_range = (torch.max(tensor) - torch.min(tensor)).cpu().detach().numpy() f = open(name, "w") + f.writelines("tensor mean: %s" % tensor_mean + os.linesep) + f.writelines("tensor range: %s" % tensor_range + os.linesep) data = tensor.reshape([-1]).float().cpu().detach().numpy().tolist() for d in data: f.writelines("%s" % d + os.linesep)