Update show and q@k dump.
This commit is contained in:
parent
ae6ea67bbe
commit
17a2df2e6f
|
@ -45,9 +45,11 @@ class ResearchRunner(QwenRunner):
|
||||||
scale_factor = 1 / math.sqrt(query.size(-1))
|
scale_factor = 1 / math.sqrt(query.size(-1))
|
||||||
attn_weight = query @ key.transpose(-2, -1) * scale_factor
|
attn_weight = query @ key.transpose(-2, -1) * scale_factor
|
||||||
attn_weight = torch.softmax(attn_weight, dim=-1)
|
attn_weight = torch.softmax(attn_weight, dim=-1)
|
||||||
|
|
||||||
size = query.shape[2]
|
size = query.shape[2]
|
||||||
qk = attn_weight[0]
|
attn_mask = torch.ones(causal_mask.shape, dtype=query.dtype, device=query.device)
|
||||||
|
attn_mask.masked_fill_(causal_mask.logical_not(), float(0))
|
||||||
|
qk = attn_weight * attn_mask
|
||||||
|
qk = qk[0]
|
||||||
prePath = "./temp/"
|
prePath = "./temp/"
|
||||||
show.DumpTensorToImage(qk, prePath + "q@k_seq_" + str(size) + "_layer_" + str(attention.index) + ".png")
|
show.DumpTensorToImage(qk, prePath + "q@k_seq_" + str(size) + "_layer_" + str(attention.index) + ".png")
|
||||||
|
|
||||||
|
|
|
@ -8,27 +8,33 @@ import numpy as np
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
|
||||||
def DumpTensorToImage(tensor, name, forceSquare=True, scale=1.0):
|
def DumpTensorToImage(tensor, name, forceSquare=True, scale=1.0, AutoContrast=True):
|
||||||
if len(tensor.shape) != 2 and len(tensor.shape) != 1 and len(tensor.shape) != 3:
|
if len(tensor.shape) != 2 and len(tensor.shape) != 1 and len(tensor.shape) != 3:
|
||||||
raise ("Error input dims")
|
raise ("Error input dims")
|
||||||
|
|
||||||
if len(tensor.shape) == 3:
|
if len(tensor.shape) == 3:
|
||||||
channel = tensor.shape[0]
|
channel = tensor.shape[0]
|
||||||
x = math.ceil((channel) ** 0.5)
|
x = math.ceil((channel) ** 0.5)
|
||||||
tensor = F.pad(
|
tensor = F.pad(tensor, (0, 1, 0, 1, 0, x * x - channel), mode="constant", value=0)
|
||||||
tensor, (0, 1, 0, 1, 0, x * x - channel), mode="constant", value=0
|
if AutoContrast:
|
||||||
)
|
calc = tensor.reshape((x * x, tensor.shape[1] * tensor.shape[2]))
|
||||||
tensor = tensor.reshape((x, x, tensor.shape[1], tensor.shape[2]))
|
tensormax = calc.max(1)[0]
|
||||||
|
tensormin = calc.min(1)[0]
|
||||||
|
calc = calc.transpose(1, 0)
|
||||||
|
calc = ((calc - tensormin) / (tensormax - tensormin)) * 255
|
||||||
|
calc = calc.transpose(1, 0)
|
||||||
|
tensor = calc.reshape((x, x, tensor.shape[1], tensor.shape[2]))
|
||||||
tensor = tensor.permute((0, 2, 1, 3))
|
tensor = tensor.permute((0, 2, 1, 3))
|
||||||
tensor = tensor.reshape((x * tensor.shape[1], x * tensor.shape[3]))
|
tensor = tensor.reshape((x * tensor.shape[1], x * tensor.shape[3]))
|
||||||
DumpTensorToImage(tensor, name, forceSquare=False, scale=scale)
|
DumpTensorToImage(tensor, name, forceSquare=False, scale=scale, AutoContrast=False)
|
||||||
return
|
return
|
||||||
|
|
||||||
tensor = tensor.float()
|
tensor = tensor.float()
|
||||||
maxv = torch.max(tensor)
|
if AutoContrast:
|
||||||
minv = torch.min(tensor)
|
maxv = torch.max(tensor)
|
||||||
tensor = (((tensor - minv) / (maxv - minv)) * 255).byte().cpu()
|
minv = torch.min(tensor)
|
||||||
img = tensor.numpy()
|
tensor = ((tensor - minv) / (maxv - minv)) * 255
|
||||||
|
img = tensor.byte().cpu().numpy()
|
||||||
srp = img.shape
|
srp = img.shape
|
||||||
|
|
||||||
if len(srp) == 1: # 1D的数据自动折叠成2D图像
|
if len(srp) == 1: # 1D的数据自动折叠成2D图像
|
||||||
|
|
Loading…
Reference in New Issue