| 
									
										
										
										
											2023-12-21 19:52:19 +08:00
										 |  |  | import plotly_express as px | 
					
						
							|  |  |  | import torch | 
					
						
							|  |  |  | import torch.nn.functional as F | 
					
						
							|  |  |  | import torchvision.transforms.functional as Vision | 
					
						
							|  |  |  | import cv2 | 
					
						
							| 
									
										
										
										
											2023-12-21 21:20:49 +08:00
										 |  |  | import math | 
					
						
							|  |  |  | import numpy as np | 
					
						
							| 
									
										
										
										
											2023-12-25 22:53:53 +08:00
										 |  |  | import os | 
					
						
							| 
									
										
										
										
											2024-09-16 18:46:09 +08:00
										 |  |  | from pathlib import Path | 
					
						
							| 
									
										
										
										
											2023-12-21 19:52:19 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-02 17:52:33 +08:00
										 |  |  | def DumpTensorToImage(tensor, name, forceSquare=False, scale=1.0, Contrast=None, GridValue=None): | 
					
						
							| 
									
										
										
										
											2024-01-13 16:48:56 +08:00
										 |  |  |     if len(tensor.shape) != 2 and len(tensor.shape) != 1 and len(tensor.shape) != 3: | 
					
						
							| 
									
										
										
										
											2023-12-21 19:52:19 +08:00
										 |  |  |         raise ("Error input dims") | 
					
						
							| 
									
										
										
										
											2024-08-18 16:38:13 +08:00
										 |  |  |     if ("." not in name) or (name.split(".")[-1] not in {"jpg", "png", "bmp"}): | 
					
						
							|  |  |  |         raise ("Error input name") | 
					
						
							| 
									
										
										
										
											2023-12-21 21:20:49 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-13 16:48:56 +08:00
										 |  |  |     if len(tensor.shape) == 3: | 
					
						
							|  |  |  |         channel = tensor.shape[0] | 
					
						
							|  |  |  |         x = math.ceil((channel) ** 0.5) | 
					
						
							| 
									
										
										
										
											2024-08-29 16:47:06 +08:00
										 |  |  |         calc = tensor.reshape((channel, tensor.shape[1] * tensor.shape[2])) | 
					
						
							|  |  |  |         if not Contrast: | 
					
						
							| 
									
										
										
										
											2024-01-21 20:50:36 +08:00
										 |  |  |             tensormax = calc.max(1)[0] | 
					
						
							|  |  |  |             tensormin = calc.min(1)[0] | 
					
						
							| 
									
										
										
										
											2024-08-29 16:47:06 +08:00
										 |  |  |         else: | 
					
						
							|  |  |  |             tensormax = Contrast[1] | 
					
						
							|  |  |  |             tensormin = Contrast[0] | 
					
						
							|  |  |  |         calc = calc.transpose(1, 0) | 
					
						
							|  |  |  |         calc = ((calc - tensormin) / (tensormax - tensormin)) * 255.0 | 
					
						
							|  |  |  |         calc = calc.transpose(1, 0) | 
					
						
							|  |  |  |         calc = calc.reshape((channel, tensor.shape[1], tensor.shape[2])) | 
					
						
							|  |  |  |         if not GridValue: | 
					
						
							|  |  |  |             GridValue = 128.0 | 
					
						
							|  |  |  |         calc = F.pad(calc, (0, 0, 0, 0, 0, x * x - channel), mode="constant", value=GridValue) | 
					
						
							| 
									
										
										
										
											2024-01-22 20:57:27 +08:00
										 |  |  |         calc = calc.reshape((x, x, tensor.shape[1], tensor.shape[2])) | 
					
						
							|  |  |  |         calc = F.pad(calc, (0, 1, 0, 1, 0, 0), mode="constant", value=GridValue) | 
					
						
							|  |  |  |         tensor = calc.permute((0, 2, 1, 3)) | 
					
						
							| 
									
										
										
										
											2024-01-13 16:48:56 +08:00
										 |  |  |         tensor = tensor.reshape((x * tensor.shape[1], x * tensor.shape[3])) | 
					
						
							| 
									
										
										
										
											2024-08-29 16:47:06 +08:00
										 |  |  |         DumpTensorToImage(tensor, name, forceSquare=False, scale=scale, Contrast=[0.0, 255.0], GridValue=GridValue) | 
					
						
							| 
									
										
										
										
											2024-01-13 16:48:56 +08:00
										 |  |  |         return | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-12-21 19:52:19 +08:00
										 |  |  |     tensor = tensor.float() | 
					
						
							| 
									
										
										
										
											2024-08-29 16:47:06 +08:00
										 |  |  |     if not Contrast: | 
					
						
							| 
									
										
										
										
											2024-01-21 20:50:36 +08:00
										 |  |  |         maxv = torch.max(tensor) | 
					
						
							|  |  |  |         minv = torch.min(tensor) | 
					
						
							| 
									
										
										
										
											2024-08-29 16:47:06 +08:00
										 |  |  |     else: | 
					
						
							|  |  |  |         maxv = Contrast[1] | 
					
						
							|  |  |  |         minv = Contrast[0] | 
					
						
							|  |  |  |     tensor = ((tensor - minv) / (maxv - minv)) * 255.0 | 
					
						
							|  |  |  |     img = tensor.detach().cpu().numpy() | 
					
						
							| 
									
										
										
										
											2023-12-21 19:52:19 +08:00
										 |  |  |     srp = img.shape | 
					
						
							| 
									
										
										
										
											2023-12-21 21:20:49 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-13 16:48:56 +08:00
										 |  |  |     if len(srp) == 1:  # 1D的数据自动折叠成2D图像 | 
					
						
							| 
									
										
										
										
											2023-12-21 21:20:49 +08:00
										 |  |  |         ceiled = math.ceil((srp[0]) ** 0.5) | 
					
						
							|  |  |  |         img = cv2.copyMakeBorder(img, 0, ceiled * ceiled - srp[0], 0, 0, 0) | 
					
						
							|  |  |  |         img = img.reshape((ceiled, ceiled)) | 
					
						
							|  |  |  |         srp = img.shape | 
					
						
							| 
									
										
										
										
											2024-01-13 16:48:56 +08:00
										 |  |  |     if forceSquare:  # 拉伸成正方形 | 
					
						
							| 
									
										
										
										
											2023-12-21 21:20:49 +08:00
										 |  |  |         img = cv2.resize(img, [max(srp), max(srp)]) | 
					
						
							|  |  |  |         srp = img.shape | 
					
						
							| 
									
										
										
										
											2023-12-21 19:52:19 +08:00
										 |  |  |     if scale != 1.0: | 
					
						
							|  |  |  |         img = cv2.resize(img, [int(srp[0] * scale), int(srp[1] * scale)]) | 
					
						
							| 
									
										
										
										
											2023-12-21 21:20:49 +08:00
										 |  |  |         srp = img.shape | 
					
						
							| 
									
										
										
										
											2024-08-29 16:47:06 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     img = img * (-1) | 
					
						
							|  |  |  |     img = img + 255 | 
					
						
							|  |  |  |     img[img < 0] = 0 | 
					
						
							|  |  |  |     img = np.nan_to_num(img, nan=0.0) | 
					
						
							|  |  |  |     img[img > 255] = 255 | 
					
						
							|  |  |  |     imgs = img.astype(np.uint8) | 
					
						
							|  |  |  |     imgs = cv2.applyColorMap(imgs, cv2.COLORMAP_JET) | 
					
						
							| 
									
										
										
										
											2024-09-16 18:46:09 +08:00
										 |  |  |     directory = Path(name).parent | 
					
						
							|  |  |  |     if not directory.is_dir(): | 
					
						
							|  |  |  |         directory.mkdir(parents=True, exist_ok=True) | 
					
						
							| 
									
										
										
										
											2024-08-29 16:47:06 +08:00
										 |  |  |     cv2.imwrite(name, imgs) | 
					
						
							| 
									
										
										
										
											2023-12-25 22:53:53 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def DumpTensorToLog(tensor, name="log"): | 
					
						
							|  |  |  |     shape = tensor.shape | 
					
						
							|  |  |  |     f = open(name, "w") | 
					
						
							| 
									
										
										
										
											2024-08-29 16:47:06 +08:00
										 |  |  |     data = tensor.reshape([-1]).float().cpu().detach().numpy().tolist() | 
					
						
							| 
									
										
										
										
											2023-12-25 22:53:53 +08:00
										 |  |  |     for d in data: | 
					
						
							|  |  |  |         f.writelines("%s" % d + os.linesep) | 
					
						
							|  |  |  |     f.close() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-12-26 14:08:02 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-12-25 22:53:53 +08:00
										 |  |  | def DumpTensorToFile(tensor, name="tensor.pt"): | 
					
						
							| 
									
										
										
										
											2023-12-26 14:08:02 +08:00
										 |  |  |     torch.save(tensor.cpu(), name) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-12-25 22:53:53 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | def LoadTensorToFile(name="tensor.pt"): | 
					
						
							| 
									
										
										
										
											2023-12-26 14:08:02 +08:00
										 |  |  |     return torch.load(name) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def DumpListToFile(list, name="list"): | 
					
						
							|  |  |  |     f = open(name, "w") | 
					
						
							|  |  |  |     for d in list: | 
					
						
							|  |  |  |         f.writelines("%s" % d + os.linesep) | 
					
						
							|  |  |  |     f.close() | 
					
						
							| 
									
										
										
										
											2023-12-29 19:55:53 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | prob_true = 0 | 
					
						
							|  |  |  | prob_all = 0 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def ProbGE0(tensor: torch.tensor): | 
					
						
							|  |  |  |     global prob_true | 
					
						
							|  |  |  |     global prob_all | 
					
						
							|  |  |  |     m = tensor.ge(0) | 
					
						
							|  |  |  |     prob_true = prob_true + m.sum().item() | 
					
						
							|  |  |  |     prob_all = prob_all + math.prod(tensor.shape) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def DumpProb(): | 
					
						
							|  |  |  |     global prob_true | 
					
						
							|  |  |  |     global prob_all | 
					
						
							|  |  |  |     print("prob_true : " + str(prob_true)) | 
					
						
							|  |  |  |     print("prob_all : " + str(prob_all)) | 
					
						
							|  |  |  |     print("prob : " + str((prob_true * 100) / prob_all) + "%") |