diff --git a/qwen/research_attention.py b/qwen/research_attention.py index 3579459..3ab98c3 100644 --- a/qwen/research_attention.py +++ b/qwen/research_attention.py @@ -45,9 +45,11 @@ class ResearchRunner(QwenRunner): scale_factor = 1 / math.sqrt(query.size(-1)) attn_weight = query @ key.transpose(-2, -1) * scale_factor attn_weight = torch.softmax(attn_weight, dim=-1) - size = query.shape[2] - qk = attn_weight[0] + attn_mask = torch.ones(causal_mask.shape, dtype=query.dtype, device=query.device) + attn_mask.masked_fill_(causal_mask.logical_not(), float(0)) + qk = attn_weight * attn_mask + qk = qk[0] prePath = "./temp/" show.DumpTensorToImage(qk, prePath + "q@k_seq_" + str(size) + "_layer_" + str(attention.index) + ".png") diff --git a/tools/show.py b/tools/show.py index 3eb4183..f83d662 100644 --- a/tools/show.py +++ b/tools/show.py @@ -8,27 +8,33 @@ import numpy as np import os -def DumpTensorToImage(tensor, name, forceSquare=True, scale=1.0): +def DumpTensorToImage(tensor, name, forceSquare=True, scale=1.0, AutoContrast=True): if len(tensor.shape) != 2 and len(tensor.shape) != 1 and len(tensor.shape) != 3: raise ("Error input dims") if len(tensor.shape) == 3: channel = tensor.shape[0] x = math.ceil((channel) ** 0.5) - tensor = F.pad( - tensor, (0, 1, 0, 1, 0, x * x - channel), mode="constant", value=0 - ) - tensor = tensor.reshape((x, x, tensor.shape[1], tensor.shape[2])) + tensor = F.pad(tensor, (0, 1, 0, 1, 0, x * x - channel), mode="constant", value=0) + if AutoContrast: + calc = tensor.reshape((x * x, tensor.shape[1] * tensor.shape[2])) + tensormax = calc.max(1)[0] + tensormin = calc.min(1)[0] + calc = calc.transpose(1, 0) + calc = ((calc - tensormin) / (tensormax - tensormin)) * 255 + calc = calc.transpose(1, 0) + tensor = calc.reshape((x, x, tensor.shape[1], tensor.shape[2])) tensor = tensor.permute((0, 2, 1, 3)) tensor = tensor.reshape((x * tensor.shape[1], x * tensor.shape[3])) - DumpTensorToImage(tensor, name, forceSquare=False, scale=scale) + DumpTensorToImage(tensor, name, forceSquare=False, scale=scale, AutoContrast=False) return tensor = tensor.float() - maxv = torch.max(tensor) - minv = torch.min(tensor) - tensor = (((tensor - minv) / (maxv - minv)) * 255).byte().cpu() - img = tensor.numpy() + if AutoContrast: + maxv = torch.max(tensor) + minv = torch.min(tensor) + tensor = ((tensor - minv) / (maxv - minv)) * 255 + img = tensor.byte().cpu().numpy() srp = img.shape if len(srp) == 1: # 1D的数据自动折叠成2D图像