Update get_relation_distance
This commit is contained in:
parent
927c98e823
commit
a7e0bd508c
|
@ -72,13 +72,11 @@ def DumpTensorToImage(tensor, name, forceSquare=False, scale=1.0, Contrast=None,
|
||||||
img = cv2.resize(img, [int(srp[0] * scale), int(srp[1] * scale)])
|
img = cv2.resize(img, [int(srp[0] * scale), int(srp[1] * scale)])
|
||||||
srp = img.shape
|
srp = img.shape
|
||||||
|
|
||||||
img = img * (-1)
|
|
||||||
img = img + 255
|
|
||||||
img[img < 0] = 0
|
img[img < 0] = 0
|
||||||
img = np.nan_to_num(img, nan=0.0)
|
img = np.nan_to_num(img, nan=0.0)
|
||||||
img[img > 255] = 255
|
img[img > 255] = 255
|
||||||
imgs = img.astype(np.uint8)
|
imgs = img.astype(np.uint8)
|
||||||
imgs = cv2.applyColorMap(imgs, cv2.COLORMAP_JET)
|
imgs = cv2.applyColorMap(imgs, cv2.COLORMAP_TURBO)
|
||||||
directory = Path(name).parent
|
directory = Path(name).parent
|
||||||
if not directory.is_dir():
|
if not directory.is_dir():
|
||||||
directory.mkdir(parents=True, exist_ok=True)
|
directory.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
|
@ -211,6 +211,8 @@ class MeaningMap:
|
||||||
common_to_current[-1] = common_to_current[-1] + 1
|
common_to_current[-1] = common_to_current[-1] + 1
|
||||||
level_change(ms_map, m, current_to_common, common_to_current)
|
level_change(ms_map, m, current_to_common, common_to_current)
|
||||||
else:
|
else:
|
||||||
|
common_to_current[-1] = common_to_current[-1] + 1
|
||||||
|
current_to_common[-1] = current_to_common[-1] + 1
|
||||||
current_to_common.append(0)
|
current_to_common.append(0)
|
||||||
common_to_current.append(0)
|
common_to_current.append(0)
|
||||||
current_to_common[-2] = current_to_common[-2] + 1
|
current_to_common[-2] = current_to_common[-2] + 1
|
||||||
|
@ -225,14 +227,12 @@ class MeaningMap:
|
||||||
return current_to_common, common_to_current
|
return current_to_common, common_to_current
|
||||||
|
|
||||||
# 根据meaning的层级结构范围一个二位的数组,表示所有token跟前面token是否有关系
|
# 根据meaning的层级结构范围一个二位的数组,表示所有token跟前面token是否有关系
|
||||||
def get_relation_table(self, meaning):
|
def get_relation_attention(self, meaning):
|
||||||
current_to_common, common_to_current = self.get_level_change(meaning)
|
current_to_common, common_to_current = self.get_level_change(meaning)
|
||||||
width = len(current_to_common)
|
width = len(current_to_common)
|
||||||
relation = np.zeros((width, width), dtype=int)
|
relation = np.zeros((width, width), dtype=int)
|
||||||
relation[0, 0] = 1
|
relation[0, 0] = 1
|
||||||
for i in range(1, width, 1):
|
for i in range(1, width, 1):
|
||||||
if i == width - 2:
|
|
||||||
print(1)
|
|
||||||
ori = current_to_common[i] - common_to_current[i]
|
ori = current_to_common[i] - common_to_current[i]
|
||||||
start_index = width
|
start_index = width
|
||||||
for s in range(i - 1, -1, -1):
|
for s in range(i - 1, -1, -1):
|
||||||
|
@ -243,6 +243,22 @@ class MeaningMap:
|
||||||
relation[i, start_index : i + 1] = 1.0
|
relation[i, start_index : i + 1] = 1.0
|
||||||
return relation
|
return relation
|
||||||
|
|
||||||
|
# 根据meaning的层级结构范围一个二位的数组,表示所有token跟前面每个token的距离
|
||||||
|
def get_relation_distance(self, meaning):
|
||||||
|
current_to_common, common_to_current = self.get_level_change(meaning)
|
||||||
|
width = len(current_to_common)
|
||||||
|
relation = np.zeros((width, width), dtype=int)
|
||||||
|
relation[0, 0] = 0
|
||||||
|
for i in range(1, width, 1):
|
||||||
|
com_dis = common_to_current[i] # 表示当前i到s的路径上最远公共点的距离
|
||||||
|
direct_dis = common_to_current[i] # 表示i到s的相对距离
|
||||||
|
for s in range(i - 1, -1, -1):
|
||||||
|
direct_dis = direct_dis - current_to_common[s]
|
||||||
|
com_dis = max(current_to_common[s] + direct_dis, com_dis)
|
||||||
|
relation[i, s] = com_dis + com_dis - direct_dis
|
||||||
|
direct_dis = direct_dis + common_to_current[s]
|
||||||
|
return relation
|
||||||
|
|
||||||
def max_length(self):
|
def max_length(self):
|
||||||
return max(self.ms_len)
|
return max(self.ms_len)
|
||||||
|
|
||||||
|
|
|
@ -25,7 +25,7 @@ if __name__ == "__main__":
|
||||||
runner = ModelRunner(qwen.llm)
|
runner = ModelRunner(qwen.llm)
|
||||||
|
|
||||||
def DumpQK(query, key, causal_mask, index):
|
def DumpQK(query, key, causal_mask, index):
|
||||||
global relation_table
|
global relation_distance
|
||||||
size = query.shape[2]
|
size = query.shape[2]
|
||||||
scale_factor = 1 / math.sqrt(query.size(-1))
|
scale_factor = 1 / math.sqrt(query.size(-1))
|
||||||
attn_weight = query @ key.transpose(-2, -1) * scale_factor
|
attn_weight = query @ key.transpose(-2, -1) * scale_factor
|
||||||
|
@ -37,8 +37,8 @@ if __name__ == "__main__":
|
||||||
qk = attn_weight[0]
|
qk = attn_weight[0]
|
||||||
prePath = "./temp/" + "q@k_seq_" + str(size) + "_layer_" + str(index) + ".png"
|
prePath = "./temp/" + "q@k_seq_" + str(size) + "_layer_" + str(index) + ".png"
|
||||||
qk = qk.cpu()
|
qk = qk.cpu()
|
||||||
qk = torch.cat((qk, relation_table.unsqueeze(0)), dim=0)
|
qk = torch.cat((qk, relation_distance.unsqueeze(0)), dim=0)
|
||||||
show.DumpTensorToImage(qk, prePath, GridValue=255)
|
show.DumpTensorToImage(qk, prePath)
|
||||||
# qk_seq.append(qk)
|
# qk_seq.append(qk)
|
||||||
# qk_index = size
|
# qk_index = size
|
||||||
|
|
||||||
|
@ -53,22 +53,21 @@ if __name__ == "__main__":
|
||||||
# seq:991
|
# seq:991
|
||||||
# seq:995
|
# seq:995
|
||||||
meaning = 995
|
meaning = 995
|
||||||
node = map.get_nodetree(meaning)
|
|
||||||
current_to_common, common_to_current = map.get_level_change(meaning)
|
|
||||||
|
|
||||||
|
node = map.get_nodetree(meaning)
|
||||||
node.print()
|
node.print()
|
||||||
print(current_to_common)
|
|
||||||
print(common_to_current)
|
# current_to_common, common_to_current = map.get_level_change(meaning)
|
||||||
relation_table = map.get_relation_table(meaning)
|
# print(current_to_common)
|
||||||
# prePath = "./temp/" + "q@k_seq_" + "_layer_" + ".png"
|
# print(common_to_current)
|
||||||
# show.DumpTensorToImage(relation_table, prePath, GridValue=255)
|
|
||||||
relation_table = torch.tensor(relation_table)
|
relation_distance = map.get_relation_distance(meaning)
|
||||||
|
relation_distance = torch.tensor(relation_distance)
|
||||||
|
rd = relation_distance.unsqueeze(0)
|
||||||
|
rd = torch.cat((rd, rd, rd, rd, rd, rd), dim=0)
|
||||||
|
show.DumpTensorToImage(rd, "./temp/q@k_seq_layer.png")
|
||||||
|
|
||||||
item, level, rank_idx, rank_all = map.get_sequence(meaning)
|
item, level, rank_idx, rank_all = map.get_sequence(meaning)
|
||||||
print(item)
|
|
||||||
print(level)
|
|
||||||
print(rank_idx)
|
|
||||||
print(rank_all)
|
|
||||||
print("len of seq:" + str(len(item)))
|
print("len of seq:" + str(len(item)))
|
||||||
|
|
||||||
batch = torch.tensor([item], dtype=torch.int64)
|
batch = torch.tensor([item], dtype=torch.int64)
|
||||||
|
|
Loading…
Reference in New Issue