From a7e0bd508cde3712c6ec23bb34f536a976f21ef8 Mon Sep 17 00:00:00 2001 From: Colin Date: Thu, 26 Jun 2025 21:42:38 +0800 Subject: [PATCH] Update get_relation_distance --- tools/show.py | 4 +--- wit/dataset/meaning_dataset.py | 22 +++++++++++++++++++--- wit/query_block_output.py | 29 ++++++++++++++--------------- 3 files changed, 34 insertions(+), 21 deletions(-) diff --git a/tools/show.py b/tools/show.py index e6bf0ff..7e76e35 100644 --- a/tools/show.py +++ b/tools/show.py @@ -72,13 +72,11 @@ def DumpTensorToImage(tensor, name, forceSquare=False, scale=1.0, Contrast=None, img = cv2.resize(img, [int(srp[0] * scale), int(srp[1] * scale)]) srp = img.shape - img = img * (-1) - img = img + 255 img[img < 0] = 0 img = np.nan_to_num(img, nan=0.0) img[img > 255] = 255 imgs = img.astype(np.uint8) - imgs = cv2.applyColorMap(imgs, cv2.COLORMAP_JET) + imgs = cv2.applyColorMap(imgs, cv2.COLORMAP_TURBO) directory = Path(name).parent if not directory.is_dir(): directory.mkdir(parents=True, exist_ok=True) diff --git a/wit/dataset/meaning_dataset.py b/wit/dataset/meaning_dataset.py index 6f04434..61f9d35 100644 --- a/wit/dataset/meaning_dataset.py +++ b/wit/dataset/meaning_dataset.py @@ -211,6 +211,8 @@ class MeaningMap: common_to_current[-1] = common_to_current[-1] + 1 level_change(ms_map, m, current_to_common, common_to_current) else: + common_to_current[-1] = common_to_current[-1] + 1 + current_to_common[-1] = current_to_common[-1] + 1 current_to_common.append(0) common_to_current.append(0) current_to_common[-2] = current_to_common[-2] + 1 @@ -225,14 +227,12 @@ class MeaningMap: return current_to_common, common_to_current # 根据meaning的层级结构范围一个二位的数组,表示所有token跟前面token是否有关系 - def get_relation_table(self, meaning): + def get_relation_attention(self, meaning): current_to_common, common_to_current = self.get_level_change(meaning) width = len(current_to_common) relation = np.zeros((width, width), dtype=int) relation[0, 0] = 1 for i in range(1, width, 1): - if i == width - 2: - print(1) ori = current_to_common[i] - common_to_current[i] start_index = width for s in range(i - 1, -1, -1): @@ -243,6 +243,22 @@ class MeaningMap: relation[i, start_index : i + 1] = 1.0 return relation + # 根据meaning的层级结构范围一个二位的数组,表示所有token跟前面每个token的距离 + def get_relation_distance(self, meaning): + current_to_common, common_to_current = self.get_level_change(meaning) + width = len(current_to_common) + relation = np.zeros((width, width), dtype=int) + relation[0, 0] = 0 + for i in range(1, width, 1): + com_dis = common_to_current[i] # 表示当前i到s的路径上最远公共点的距离 + direct_dis = common_to_current[i] # 表示i到s的相对距离 + for s in range(i - 1, -1, -1): + direct_dis = direct_dis - current_to_common[s] + com_dis = max(current_to_common[s] + direct_dis, com_dis) + relation[i, s] = com_dis + com_dis - direct_dis + direct_dis = direct_dis + common_to_current[s] + return relation + def max_length(self): return max(self.ms_len) diff --git a/wit/query_block_output.py b/wit/query_block_output.py index 85039f5..d231ec3 100644 --- a/wit/query_block_output.py +++ b/wit/query_block_output.py @@ -25,7 +25,7 @@ if __name__ == "__main__": runner = ModelRunner(qwen.llm) def DumpQK(query, key, causal_mask, index): - global relation_table + global relation_distance size = query.shape[2] scale_factor = 1 / math.sqrt(query.size(-1)) attn_weight = query @ key.transpose(-2, -1) * scale_factor @@ -37,8 +37,8 @@ if __name__ == "__main__": qk = attn_weight[0] prePath = "./temp/" + "q@k_seq_" + str(size) + "_layer_" + str(index) + ".png" qk = qk.cpu() - qk = torch.cat((qk, relation_table.unsqueeze(0)), dim=0) - show.DumpTensorToImage(qk, prePath, GridValue=255) + qk = torch.cat((qk, relation_distance.unsqueeze(0)), dim=0) + show.DumpTensorToImage(qk, prePath) # qk_seq.append(qk) # qk_index = size @@ -53,22 +53,21 @@ if __name__ == "__main__": # seq:991 # seq:995 meaning = 995 - node = map.get_nodetree(meaning) - current_to_common, common_to_current = map.get_level_change(meaning) + node = map.get_nodetree(meaning) node.print() - print(current_to_common) - print(common_to_current) - relation_table = map.get_relation_table(meaning) - # prePath = "./temp/" + "q@k_seq_" + "_layer_" + ".png" - # show.DumpTensorToImage(relation_table, prePath, GridValue=255) - relation_table = torch.tensor(relation_table) + + # current_to_common, common_to_current = map.get_level_change(meaning) + # print(current_to_common) + # print(common_to_current) + + relation_distance = map.get_relation_distance(meaning) + relation_distance = torch.tensor(relation_distance) + rd = relation_distance.unsqueeze(0) + rd = torch.cat((rd, rd, rd, rd, rd, rd), dim=0) + show.DumpTensorToImage(rd, "./temp/q@k_seq_layer.png") item, level, rank_idx, rank_all = map.get_sequence(meaning) - print(item) - print(level) - print(rank_idx) - print(rank_all) print("len of seq:" + str(len(item))) batch = torch.tensor([item], dtype=torch.int64)