Refine show.
This commit is contained in:
parent
b860d794a6
commit
950055c210
|
@ -11,6 +11,8 @@ import os
|
||||||
def DumpTensorToImage(tensor, name, forceSquare=True, scale=1.0, AutoContrast=True, GridValue=0):
|
def DumpTensorToImage(tensor, name, forceSquare=True, scale=1.0, AutoContrast=True, GridValue=0):
|
||||||
if len(tensor.shape) != 2 and len(tensor.shape) != 1 and len(tensor.shape) != 3:
|
if len(tensor.shape) != 2 and len(tensor.shape) != 1 and len(tensor.shape) != 3:
|
||||||
raise ("Error input dims")
|
raise ("Error input dims")
|
||||||
|
if ("." not in name) or (name.split(".")[-1] not in {"jpg", "png", "bmp"}):
|
||||||
|
raise ("Error input name")
|
||||||
|
|
||||||
if len(tensor.shape) == 3:
|
if len(tensor.shape) == 3:
|
||||||
channel = tensor.shape[0]
|
channel = tensor.shape[0]
|
||||||
|
|
Loading…
Reference in New Issue