Update get_relation_distance
This commit is contained in:
parent
927c98e823
commit
a7e0bd508c
|
@ -72,13 +72,11 @@ def DumpTensorToImage(tensor, name, forceSquare=False, scale=1.0, Contrast=None,
|
|||
img = cv2.resize(img, [int(srp[0] * scale), int(srp[1] * scale)])
|
||||
srp = img.shape
|
||||
|
||||
img = img * (-1)
|
||||
img = img + 255
|
||||
img[img < 0] = 0
|
||||
img = np.nan_to_num(img, nan=0.0)
|
||||
img[img > 255] = 255
|
||||
imgs = img.astype(np.uint8)
|
||||
imgs = cv2.applyColorMap(imgs, cv2.COLORMAP_JET)
|
||||
imgs = cv2.applyColorMap(imgs, cv2.COLORMAP_TURBO)
|
||||
directory = Path(name).parent
|
||||
if not directory.is_dir():
|
||||
directory.mkdir(parents=True, exist_ok=True)
|
||||
|
|
|
@ -211,6 +211,8 @@ class MeaningMap:
|
|||
common_to_current[-1] = common_to_current[-1] + 1
|
||||
level_change(ms_map, m, current_to_common, common_to_current)
|
||||
else:
|
||||
common_to_current[-1] = common_to_current[-1] + 1
|
||||
current_to_common[-1] = current_to_common[-1] + 1
|
||||
current_to_common.append(0)
|
||||
common_to_current.append(0)
|
||||
current_to_common[-2] = current_to_common[-2] + 1
|
||||
|
@ -225,14 +227,12 @@ class MeaningMap:
|
|||
return current_to_common, common_to_current
|
||||
|
||||
# 根据meaning的层级结构范围一个二位的数组,表示所有token跟前面token是否有关系
|
||||
def get_relation_table(self, meaning):
|
||||
def get_relation_attention(self, meaning):
|
||||
current_to_common, common_to_current = self.get_level_change(meaning)
|
||||
width = len(current_to_common)
|
||||
relation = np.zeros((width, width), dtype=int)
|
||||
relation[0, 0] = 1
|
||||
for i in range(1, width, 1):
|
||||
if i == width - 2:
|
||||
print(1)
|
||||
ori = current_to_common[i] - common_to_current[i]
|
||||
start_index = width
|
||||
for s in range(i - 1, -1, -1):
|
||||
|
@ -243,6 +243,22 @@ class MeaningMap:
|
|||
relation[i, start_index : i + 1] = 1.0
|
||||
return relation
|
||||
|
||||
# 根据meaning的层级结构范围一个二位的数组,表示所有token跟前面每个token的距离
|
||||
def get_relation_distance(self, meaning):
|
||||
current_to_common, common_to_current = self.get_level_change(meaning)
|
||||
width = len(current_to_common)
|
||||
relation = np.zeros((width, width), dtype=int)
|
||||
relation[0, 0] = 0
|
||||
for i in range(1, width, 1):
|
||||
com_dis = common_to_current[i] # 表示当前i到s的路径上最远公共点的距离
|
||||
direct_dis = common_to_current[i] # 表示i到s的相对距离
|
||||
for s in range(i - 1, -1, -1):
|
||||
direct_dis = direct_dis - current_to_common[s]
|
||||
com_dis = max(current_to_common[s] + direct_dis, com_dis)
|
||||
relation[i, s] = com_dis + com_dis - direct_dis
|
||||
direct_dis = direct_dis + common_to_current[s]
|
||||
return relation
|
||||
|
||||
def max_length(self):
|
||||
return max(self.ms_len)
|
||||
|
||||
|
|
|
@ -25,7 +25,7 @@ if __name__ == "__main__":
|
|||
runner = ModelRunner(qwen.llm)
|
||||
|
||||
def DumpQK(query, key, causal_mask, index):
|
||||
global relation_table
|
||||
global relation_distance
|
||||
size = query.shape[2]
|
||||
scale_factor = 1 / math.sqrt(query.size(-1))
|
||||
attn_weight = query @ key.transpose(-2, -1) * scale_factor
|
||||
|
@ -37,8 +37,8 @@ if __name__ == "__main__":
|
|||
qk = attn_weight[0]
|
||||
prePath = "./temp/" + "q@k_seq_" + str(size) + "_layer_" + str(index) + ".png"
|
||||
qk = qk.cpu()
|
||||
qk = torch.cat((qk, relation_table.unsqueeze(0)), dim=0)
|
||||
show.DumpTensorToImage(qk, prePath, GridValue=255)
|
||||
qk = torch.cat((qk, relation_distance.unsqueeze(0)), dim=0)
|
||||
show.DumpTensorToImage(qk, prePath)
|
||||
# qk_seq.append(qk)
|
||||
# qk_index = size
|
||||
|
||||
|
@ -53,22 +53,21 @@ if __name__ == "__main__":
|
|||
# seq:991
|
||||
# seq:995
|
||||
meaning = 995
|
||||
node = map.get_nodetree(meaning)
|
||||
current_to_common, common_to_current = map.get_level_change(meaning)
|
||||
|
||||
node = map.get_nodetree(meaning)
|
||||
node.print()
|
||||
print(current_to_common)
|
||||
print(common_to_current)
|
||||
relation_table = map.get_relation_table(meaning)
|
||||
# prePath = "./temp/" + "q@k_seq_" + "_layer_" + ".png"
|
||||
# show.DumpTensorToImage(relation_table, prePath, GridValue=255)
|
||||
relation_table = torch.tensor(relation_table)
|
||||
|
||||
# current_to_common, common_to_current = map.get_level_change(meaning)
|
||||
# print(current_to_common)
|
||||
# print(common_to_current)
|
||||
|
||||
relation_distance = map.get_relation_distance(meaning)
|
||||
relation_distance = torch.tensor(relation_distance)
|
||||
rd = relation_distance.unsqueeze(0)
|
||||
rd = torch.cat((rd, rd, rd, rd, rd, rd), dim=0)
|
||||
show.DumpTensorToImage(rd, "./temp/q@k_seq_layer.png")
|
||||
|
||||
item, level, rank_idx, rank_all = map.get_sequence(meaning)
|
||||
print(item)
|
||||
print(level)
|
||||
print(rank_idx)
|
||||
print(rank_all)
|
||||
print("len of seq:" + str(len(item)))
|
||||
|
||||
batch = torch.tensor([item], dtype=torch.int64)
|
||||
|
|
Loading…
Reference in New Issue