Update get_relation_distance

This commit is contained in:
Colin 2025-06-26 21:42:38 +08:00
parent 927c98e823
commit a7e0bd508c
3 changed files with 34 additions and 21 deletions

View File

@ -72,13 +72,11 @@ def DumpTensorToImage(tensor, name, forceSquare=False, scale=1.0, Contrast=None,
img = cv2.resize(img, [int(srp[0] * scale), int(srp[1] * scale)])
srp = img.shape
img = img * (-1)
img = img + 255
img[img < 0] = 0
img = np.nan_to_num(img, nan=0.0)
img[img > 255] = 255
imgs = img.astype(np.uint8)
imgs = cv2.applyColorMap(imgs, cv2.COLORMAP_JET)
imgs = cv2.applyColorMap(imgs, cv2.COLORMAP_TURBO)
directory = Path(name).parent
if not directory.is_dir():
directory.mkdir(parents=True, exist_ok=True)

View File

@ -211,6 +211,8 @@ class MeaningMap:
common_to_current[-1] = common_to_current[-1] + 1
level_change(ms_map, m, current_to_common, common_to_current)
else:
common_to_current[-1] = common_to_current[-1] + 1
current_to_common[-1] = current_to_common[-1] + 1
current_to_common.append(0)
common_to_current.append(0)
current_to_common[-2] = current_to_common[-2] + 1
@ -225,14 +227,12 @@ class MeaningMap:
return current_to_common, common_to_current
# Based on the hierarchy of the meaning, return a 2-D array indicating whether each token is related to each of the tokens before it
def get_relation_table(self, meaning):
def get_relation_attention(self, meaning):
current_to_common, common_to_current = self.get_level_change(meaning)
width = len(current_to_common)
relation = np.zeros((width, width), dtype=int)
relation[0, 0] = 1
for i in range(1, width, 1):
if i == width - 2:
print(1)
ori = current_to_common[i] - common_to_current[i]
start_index = width
for s in range(i - 1, -1, -1):
@ -243,6 +243,22 @@ class MeaningMap:
relation[i, start_index : i + 1] = 1.0
return relation
# Based on the hierarchy of the meaning, return a 2-D array giving the distance from each token to every token before it
def get_relation_distance(self, meaning):
    """Build a matrix of distances from each token to every earlier token.

    The two level-change lists are obtained from ``self.get_level_change``.
    Entry ``[i, s]`` with ``s < i`` holds the distance computed between
    token ``i`` and the earlier token ``s``; the diagonal and everything
    above it are left at 0.

    Args:
        meaning: Passed unchanged to ``self.get_level_change``.

    Returns:
        A ``(width, width)`` integer numpy array, where ``width`` is the
        length of the level-change lists. Empty lists yield a ``(0, 0)``
        array instead of raising.
    """
    current_to_common, common_to_current = self.get_level_change(meaning)
    width = len(current_to_common)
    # np.zeros already leaves [0, 0] at 0, so no explicit write is needed;
    # skipping it also keeps width == 0 from raising IndexError.
    relation = np.zeros((width, width), dtype=int)
    for i in range(1, width):
        # Distance to the farthest common point on the path from i to s.
        com_dis = common_to_current[i]
        # Relative distance from i to s.
        direct_dis = common_to_current[i]
        # Walk backwards over every token that precedes i.
        for s in range(i - 1, -1, -1):
            direct_dis = direct_dis - current_to_common[s]
            com_dis = max(current_to_common[s] + direct_dis, com_dis)
            relation[i, s] = 2 * com_dis - direct_dis
            direct_dis = direct_dis + common_to_current[s]
    return relation
def max_length(self):
    """Return the largest value stored in ``self.ms_len``."""
    longest = max(self.ms_len)
    return longest

View File

@ -25,7 +25,7 @@ if __name__ == "__main__":
runner = ModelRunner(qwen.llm)
def DumpQK(query, key, causal_mask, index):
global relation_table
global relation_distance
size = query.shape[2]
scale_factor = 1 / math.sqrt(query.size(-1))
attn_weight = query @ key.transpose(-2, -1) * scale_factor
@ -37,8 +37,8 @@ if __name__ == "__main__":
qk = attn_weight[0]
prePath = "./temp/" + "q@k_seq_" + str(size) + "_layer_" + str(index) + ".png"
qk = qk.cpu()
qk = torch.cat((qk, relation_table.unsqueeze(0)), dim=0)
show.DumpTensorToImage(qk, prePath, GridValue=255)
qk = torch.cat((qk, relation_distance.unsqueeze(0)), dim=0)
show.DumpTensorToImage(qk, prePath)
# qk_seq.append(qk)
# qk_index = size
@ -53,22 +53,21 @@ if __name__ == "__main__":
# seq:991
# seq:995
meaning = 995
node = map.get_nodetree(meaning)
current_to_common, common_to_current = map.get_level_change(meaning)
node = map.get_nodetree(meaning)
node.print()
print(current_to_common)
print(common_to_current)
relation_table = map.get_relation_table(meaning)
# prePath = "./temp/" + "q@k_seq_" + "_layer_" + ".png"
# show.DumpTensorToImage(relation_table, prePath, GridValue=255)
relation_table = torch.tensor(relation_table)
# current_to_common, common_to_current = map.get_level_change(meaning)
# print(current_to_common)
# print(common_to_current)
relation_distance = map.get_relation_distance(meaning)
relation_distance = torch.tensor(relation_distance)
rd = relation_distance.unsqueeze(0)
rd = torch.cat((rd, rd, rd, rd, rd, rd), dim=0)
show.DumpTensorToImage(rd, "./temp/q@k_seq_layer.png")
item, level, rank_idx, rank_all = map.get_sequence(meaning)
print(item)
print(level)
print(rank_idx)
print(rank_all)
print("len of seq:" + str(len(item)))
batch = torch.tensor([item], dtype=torch.int64)