Update show and q@k dump.
This commit is contained in:
		
							parent
							
								
									ae6ea67bbe
								
							
						
					
					
						commit
						17a2df2e6f
					
				|  | @ -45,9 +45,11 @@ class ResearchRunner(QwenRunner): | ||||||
|         scale_factor = 1 / math.sqrt(query.size(-1)) |         scale_factor = 1 / math.sqrt(query.size(-1)) | ||||||
|         attn_weight = query @ key.transpose(-2, -1) * scale_factor |         attn_weight = query @ key.transpose(-2, -1) * scale_factor | ||||||
|         attn_weight = torch.softmax(attn_weight, dim=-1) |         attn_weight = torch.softmax(attn_weight, dim=-1) | ||||||
| 
 |  | ||||||
|         size = query.shape[2] |         size = query.shape[2] | ||||||
|         qk = attn_weight[0] |         attn_mask = torch.ones(causal_mask.shape, dtype=query.dtype, device=query.device) | ||||||
|  |         attn_mask.masked_fill_(causal_mask.logical_not(), float(0)) | ||||||
|  |         qk = attn_weight * attn_mask | ||||||
|  |         qk = qk[0] | ||||||
|         prePath = "./temp/" |         prePath = "./temp/" | ||||||
|         show.DumpTensorToImage(qk, prePath + "q@k_seq_" + str(size) + "_layer_" + str(attention.index) + ".png") |         show.DumpTensorToImage(qk, prePath + "q@k_seq_" + str(size) + "_layer_" + str(attention.index) + ".png") | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -8,27 +8,33 @@ import numpy as np | ||||||
| import os | import os | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def DumpTensorToImage(tensor, name, forceSquare=True, scale=1.0): | def DumpTensorToImage(tensor, name, forceSquare=True, scale=1.0, AutoContrast=True): | ||||||
|     if len(tensor.shape) != 2 and len(tensor.shape) != 1 and len(tensor.shape) != 3: |     if len(tensor.shape) != 2 and len(tensor.shape) != 1 and len(tensor.shape) != 3: | ||||||
|         raise ("Error input dims") |         raise ("Error input dims") | ||||||
| 
 | 
 | ||||||
|     if len(tensor.shape) == 3: |     if len(tensor.shape) == 3: | ||||||
|         channel = tensor.shape[0] |         channel = tensor.shape[0] | ||||||
|         x = math.ceil((channel) ** 0.5) |         x = math.ceil((channel) ** 0.5) | ||||||
|         tensor = F.pad( |         tensor = F.pad(tensor, (0, 1, 0, 1, 0, x * x - channel), mode="constant", value=0) | ||||||
|             tensor, (0, 1, 0, 1, 0, x * x - channel), mode="constant", value=0 |         if AutoContrast: | ||||||
|         ) |             calc = tensor.reshape((x * x, tensor.shape[1] * tensor.shape[2])) | ||||||
|         tensor = tensor.reshape((x, x, tensor.shape[1], tensor.shape[2])) |             tensormax = calc.max(1)[0] | ||||||
|  |             tensormin = calc.min(1)[0] | ||||||
|  |             calc = calc.transpose(1, 0) | ||||||
|  |             calc = ((calc - tensormin) / (tensormax - tensormin)) * 255 | ||||||
|  |             calc = calc.transpose(1, 0) | ||||||
|  |         tensor = calc.reshape((x, x, tensor.shape[1], tensor.shape[2])) | ||||||
|         tensor = tensor.permute((0, 2, 1, 3)) |         tensor = tensor.permute((0, 2, 1, 3)) | ||||||
|         tensor = tensor.reshape((x * tensor.shape[1], x * tensor.shape[3])) |         tensor = tensor.reshape((x * tensor.shape[1], x * tensor.shape[3])) | ||||||
|         DumpTensorToImage(tensor, name, forceSquare=False, scale=scale) |         DumpTensorToImage(tensor, name, forceSquare=False, scale=scale, AutoContrast=False) | ||||||
|         return |         return | ||||||
| 
 | 
 | ||||||
|     tensor = tensor.float() |     tensor = tensor.float() | ||||||
|     maxv = torch.max(tensor) |     if AutoContrast: | ||||||
|     minv = torch.min(tensor) |         maxv = torch.max(tensor) | ||||||
|     tensor = (((tensor - minv) / (maxv - minv)) * 255).byte().cpu() |         minv = torch.min(tensor) | ||||||
|     img = tensor.numpy() |         tensor = ((tensor - minv) / (maxv - minv)) * 255 | ||||||
|  |     img = tensor.byte().cpu().numpy() | ||||||
|     srp = img.shape |     srp = img.shape | ||||||
| 
 | 
 | ||||||
|     if len(srp) == 1:  # 1D的数据自动折叠成2D图像 |     if len(srp) == 1:  # 1D的数据自动折叠成2D图像 | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue