diff --git a/tools/show.py b/tools/show.py index 9ac7bcf..2548bf7 100644 --- a/tools/show.py +++ b/tools/show.py @@ -11,6 +11,8 @@ import os def DumpTensorToImage(tensor, name, forceSquare=True, scale=1.0, AutoContrast=True, GridValue=0): if len(tensor.shape) != 2 and len(tensor.shape) != 1 and len(tensor.shape) != 3: raise ("Error input dims") + if ("." not in name) or (name.split(".")[-1] not in {"jpg", "png", "bmp"}): + raise ("Error input name") if len(tensor.shape) == 3: channel = tensor.shape[0]