Add dump tool.
This commit is contained in:
		
							parent
							
								
									a451def299
								
							
						
					
					
						commit
						68417fdc12
					
				|  | @ -502,13 +502,9 @@ class ChatGLMModel(nn.Module): | ||||||
|         # Rotary positional embeddings |         # Rotary positional embeddings | ||||||
|         rotary_pos_emb = self.rotary_pos_emb(self.seq_length) |         rotary_pos_emb = self.rotary_pos_emb(self.seq_length) | ||||||
| 
 | 
 | ||||||
|  |         from tools import show | ||||||
| 
 | 
 | ||||||
|         import plotly_express as px |         show.DumpTensorToImage(rotary_pos_emb[:, :, 0], "plot.png", scale=0.1) | ||||||
|         img = px.imshow((rotary_pos_emb[:,:,0]*256).byte().cpu()) |  | ||||||
|         img.write_image("plot.png") |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| 
 | 
 | ||||||
|         if position_ids is not None: |         if position_ids is not None: | ||||||
|             rotary_pos_emb = rotary_pos_emb[position_ids] |             rotary_pos_emb = rotary_pos_emb[position_ids] | ||||||
|  | @ -710,7 +706,7 @@ class ChatGLMForConditionalGeneration(nn.Module): | ||||||
|             pad_token_id=generation_config.pad_token_id, |             pad_token_id=generation_config.pad_token_id, | ||||||
|             eos_token_id=generation_config.eos_token_id, |             eos_token_id=generation_config.eos_token_id, | ||||||
|             output_hidden_states=generation_config.output_hidden_states, |             output_hidden_states=generation_config.output_hidden_states, | ||||||
|             use_cache = generation_config.use_cache |             use_cache=generation_config.use_cache, | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|         outputs = outputs.tolist()[0][len(inputs["input_ids"][0]) : -1] |         outputs = outputs.tolist()[0][len(inputs["input_ids"][0]) : -1] | ||||||
|  | @ -724,7 +720,7 @@ class ChatGLMForConditionalGeneration(nn.Module): | ||||||
|         pad_token_id: Optional[int] = None, |         pad_token_id: Optional[int] = None, | ||||||
|         eos_token_id: Optional[Union[int, List[int]]] = None, |         eos_token_id: Optional[Union[int, List[int]]] = None, | ||||||
|         output_hidden_states: Optional[bool] = None, |         output_hidden_states: Optional[bool] = None, | ||||||
|         use_cache: Optional[bool] = None |         use_cache: Optional[bool] = None, | ||||||
|     ): |     ): | ||||||
|         if isinstance(eos_token_id, int): |         if isinstance(eos_token_id, int): | ||||||
|             eos_token_id = [eos_token_id] |             eos_token_id = [eos_token_id] | ||||||
|  |  | ||||||
							
								
								
									
										
											BIN
										
									
								
								plot.png
								
								
								
								
							
							
						
						
									
										
											BIN
										
									
								
								plot.png
								
								
								
								
							
										
											Binary file not shown.
										
									
								
							| Before Width: | Height: | Size: 26 KiB After Width: | Height: | Size: 262 KiB | 
|  | @ -0,0 +1,52 @@ | ||||||
|  | import plotly_express as px | ||||||
|  | import torch | ||||||
|  | import torch.nn.functional as F | ||||||
|  | import torchvision.transforms.functional as Vision | ||||||
|  | import cv2 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def DumpTensorToImage(tensor, name, autoPad=True, scale=1.0): | ||||||
|  |     if len(tensor.shape) != 2: | ||||||
|  |         raise ("Error input dims") | ||||||
|  |     tensor = tensor.float() | ||||||
|  |     maxv = torch.max(tensor) | ||||||
|  |     minv = torch.min(tensor) | ||||||
|  |     tensor = (((tensor - minv) / (maxv - minv)) * 256).byte().cpu() | ||||||
|  |     img = tensor.numpy() | ||||||
|  |     srp = img.shape | ||||||
|  |     if autoPad and (max(srp) / min(srp) > 16): | ||||||
|  |         img = cv2.resize(img,[max(srp),max(srp)]) | ||||||
|  |     srp = img.shape | ||||||
|  |     if scale != 1.0: | ||||||
|  |         img = cv2.resize(img, [int(srp[0] * scale), int(srp[1] * scale)]) | ||||||
|  |     srp = img.shape | ||||||
|  |     cv2.imwrite(name, img) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | # def DumpTensorToImage(tensor, name, autoPad=True, scale=1.0): | ||||||
|  | #     if len(tensor.shape) != 2: | ||||||
|  | #         raise ("Error input dims") | ||||||
|  | #     tensor = tensor.float() | ||||||
|  | #     maxv = torch.max(tensor) | ||||||
|  | #     minv = torch.min(tensor) | ||||||
|  | #     tensor = (((tensor - minv) / (maxv - minv)) * 256).byte().cpu() | ||||||
|  | #     srp = tensor.shape | ||||||
|  | #     if autoPad and (max(srp) / min(srp) > 16): | ||||||
|  | #         if srp[0] == min(srp): | ||||||
|  | #             tensor = F.pad(tensor, [max(srp) - min(srp), 0], "replicate") | ||||||
|  | #         else: | ||||||
|  | #             tensor = F.pad(tensor, [0, max(srp) - min(srp)], "replicate") | ||||||
|  | #     srp = tensor.shape | ||||||
|  | 
 | ||||||
|  | #     tensor = tensor.unsqueeze(0) | ||||||
|  | #     if scale != 1.0: | ||||||
|  | #         tensor = Vision.resize(tensor, [int(srp[0] * scale), int(srp[1] * scale)]) | ||||||
|  | #     tensor = tensor.view([int(srp[0] * scale), int(srp[1] * scale)]) | ||||||
|  | #     srp = tensor.shape | ||||||
|  | 
 | ||||||
|  | #     w = 1024 if max(srp) > 1024 else max(srp) | ||||||
|  | #     scale = max(srp) / w | ||||||
|  | #     # img = px.imshow(tensor) | ||||||
|  | #     # img.write_image(name) | ||||||
|  | #     cv2.imwrite(name, tensor.numpy()) | ||||||
|  | #     cv2.CreateMat(name, tensor.numpy()) | ||||||
|  | @ -0,0 +1,6 @@ | ||||||
|  | import show | ||||||
|  | import torch | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | radata = torch.randn(8192, 128) | ||||||
|  | show.DumpTensorToImage(radata, "test.png", autoPad=True,scale=0.2) | ||||||
		Loading…
	
		Reference in New Issue