Refine query.
This commit is contained in:
parent
bd5379f24e
commit
e048d0c531
|
@ -0,0 +1,161 @@
|
|||
len of seq:41
|
||||
index: 1 golden: 6 -> 15 ERR
|
||||
index: 12 golden: 10 -> 22 ERR
|
||||
995
|
||||
├── 202
|
||||
│ ├── <0> 12
|
||||
│ ├── 60
|
||||
│ │ ├── <1> 6 ERR_15
|
||||
│ │ ├── <2> 11
|
||||
│ │ ├── <3> 8
|
||||
│ │ ├── <4> 9
|
||||
│ │ └── <5> 16
|
||||
│ ├── 53
|
||||
│ │ ├── <6> 3
|
||||
│ │ ├── <7> 5
|
||||
│ │ ├── <8> 11
|
||||
│ │ ├── <9> 14
|
||||
│ │ └── <10> 13
|
||||
│ └── <11> 16
|
||||
├── 51
|
||||
│ ├── <12> 10 ERR_22
|
||||
│ ├── <13> 6
|
||||
│ └── <14> 16
|
||||
├── 205
|
||||
│ ├── 76
|
||||
│ │ ├── <15> 0
|
||||
│ │ ├── <16> 25
|
||||
│ │ ├── <17> 3
|
||||
│ │ ├── <18> 6
|
||||
│ │ └── <19> 6
|
||||
│ ├── <20> 19
|
||||
│ ├── <21> 18
|
||||
│ ├── <22> 2
|
||||
│ ├── 55
|
||||
│ │ ├── <23> 14
|
||||
│ │ ├── 31
|
||||
│ │ │ ├── <24> 8
|
||||
│ │ │ ├── <25> 5
|
||||
│ │ │ ├── <26> 5
|
||||
│ │ │ └── <27> 2
|
||||
│ │ ├── <28> 0
|
||||
│ │ ├── <29> 4
|
||||
│ │ ├── <30> 1
|
||||
│ │ └── <31> 1
|
||||
│ └── 31
|
||||
│ ├── <32> 8
|
||||
│ ├── <33> 5
|
||||
│ ├── <34> 5
|
||||
│ └── <35> 2
|
||||
└── 217
|
||||
├── <36> 3
|
||||
└── 31
|
||||
├── <37> 8
|
||||
├── <38> 5
|
||||
├── <39> 5
|
||||
└── <40> 2
|
||||
|
||||
Mapping Load from disk cache: structured_language_50000_32_1_6_2_1_False_42
|
||||
Mapping Load end, elapsed:0.27529025077819824s
|
||||
len of seq:41
|
||||
|
||||
|
||||
index: 2 golden: 6 -> 12 ERR
|
||||
995
|
||||
├── 202
|
||||
│ ├── <0> 12
|
||||
│ ├── <1> 12
|
||||
│ ├── 60
|
||||
│ │ ├── <2> 6 ERR_12
|
||||
│ │ ├── <3> 6
|
||||
│ │ ├── <4> 11
|
||||
│ │ ├── <5> 11
|
||||
│ │ ├── <6> 8
|
||||
│ │ ├── <7> 8
|
||||
│ │ ├── <8> 9
|
||||
│ │ ├── <9> 9
|
||||
│ │ ├── <10> 16
|
||||
│ │ └── <11> 16
|
||||
│ ├── 53
|
||||
│ │ ├── <12> 3
|
||||
│ │ ├── <13> 3
|
||||
│ │ ├── <14> 5
|
||||
│ │ ├── <15> 5
|
||||
│ │ ├── <16> 11
|
||||
│ │ ├── <17> 11
|
||||
│ │ ├── <18> 14
|
||||
│ │ ├── <19> 14
|
||||
│ │ ├── <20> 13
|
||||
│ │ └── <21> 13
|
||||
│ ├── <22> 16
|
||||
│ └── <23> 16
|
||||
├── 51
|
||||
│ ├── <24> 10
|
||||
│ ├── <25> 10
|
||||
│ ├── <26> 6
|
||||
│ ├── <27> 6
|
||||
│ ├── <28> 16
|
||||
│ └── <29> 16
|
||||
├── 205
|
||||
│ ├── 76
|
||||
│ │ ├── <30> 0
|
||||
│ │ ├── <31> 0
|
||||
│ │ ├── <32> 25
|
||||
│ │ ├── <33> 25
|
||||
│ │ ├── <34> 3
|
||||
│ │ ├── <35> 3
|
||||
│ │ ├── <36> 6
|
||||
│ │ ├── <37> 6
|
||||
│ │ ├── <38> 6
|
||||
│ │ └── <39> 6
|
||||
│ ├── <40> 19
|
||||
│ ├── <41> 19
|
||||
│ ├── <42> 18
|
||||
│ ├── <43> 18
|
||||
│ ├── <44> 2
|
||||
│ ├── <45> 2
|
||||
│ ├── 55
|
||||
│ │ ├── <46> 14
|
||||
│ │ ├── <47> 14
|
||||
│ │ ├── 31
|
||||
│ │ │ ├── <48> 8
|
||||
│ │ │ ├── <49> 8
|
||||
│ │ │ ├── <50> 5
|
||||
│ │ │ ├── <51> 5
|
||||
│ │ │ ├── <52> 5
|
||||
│ │ │ ├── <53> 5
|
||||
│ │ │ ├── <54> 2
|
||||
│ │ │ └── <55> 2
|
||||
│ │ ├── <56> 0
|
||||
│ │ ├── <57> 0
|
||||
│ │ ├── <58> 4
|
||||
│ │ ├── <59> 4
|
||||
│ │ ├── <60> 1
|
||||
│ │ ├── <61> 1
|
||||
│ │ ├── <62> 1
|
||||
│ │ └── <63> 1
|
||||
│ └── 31
|
||||
│ ├── <64> 8
|
||||
│ ├── <65> 8
|
||||
│ ├── <66> 5
|
||||
│ ├── <67> 5
|
||||
│ ├── <68> 5
|
||||
│ ├── <69> 5
|
||||
│ ├── <70> 2
|
||||
│ └── <71> 2
|
||||
└── 217
|
||||
├── <72> 3
|
||||
├── <73> 3
|
||||
└── 31
|
||||
├── <74> 8
|
||||
├── <75> 8
|
||||
├── <76> 5
|
||||
├── <77> 5
|
||||
├── <78> 5
|
||||
├── <79> 5
|
||||
├── <80> 2
|
||||
└── <81> 2
|
||||
|
||||
Mapping Load from disk cache: structured_language_50000_32_1_6_2_2_False_42
|
||||
Mapping Load end, elapsed:0.2843344211578369s
|
||||
len of seq:82
|
|
@ -223,6 +223,17 @@ class MeaningMap:
|
|||
self.ms_rank_all[start : start + len],
|
||||
)
|
||||
|
||||
def get_maps(self, contain=-1):
|
||||
seqlist = []
|
||||
start = self.ms_start[contain]
|
||||
seq = self.ms_data[start : start + self.ms_len[contain]]
|
||||
level = self.ms_level[start : start + self.ms_len[contain]]
|
||||
ms = self.ms_map
|
||||
for m in range(contain, len(ms)):
|
||||
if contain in ms[m]:
|
||||
seqlist.append(ms[m])
|
||||
return seqlist
|
||||
|
||||
def get_nodetree(self, meaning): # return meaning all sub items
|
||||
def get_tree_node(self, meaning, seq, level, vs, parent, seqlist, index):
|
||||
ms = self.ms_map[meaning]
|
||||
|
|
|
@ -46,11 +46,23 @@ def get_dataset_set_freq(dataset):
|
|||
print(f"meaning of {m} count as {total}")
|
||||
|
||||
|
||||
def get_inference(dataset, seq):
|
||||
def get_inference_single(dataset, meaning, index):
|
||||
map = dataset.get_meaning_map()
|
||||
|
||||
node = map.get_nodetree(seq)
|
||||
item, l, rank_idx, rank_all = map.get_sequence(seq)
|
||||
item, l, rank_idx, rank_all = map.get_sequence(meaning)
|
||||
print("len of seq:" + str(len(item)))
|
||||
|
||||
batch = torch.tensor([item[:index]], dtype=torch.int64)
|
||||
sorted_logits, sorted_indices = runner.ChatTokens(batch, sample=False)
|
||||
next_token = sorted_indices.detach().cpu().numpy()[0][0]
|
||||
return next_token
|
||||
|
||||
|
||||
def get_inference(dataset, meaning):
|
||||
map = dataset.get_meaning_map()
|
||||
|
||||
node = map.get_nodetree(meaning)
|
||||
item, l, rank_idx, rank_all = map.get_sequence(meaning)
|
||||
print("len of seq:" + str(len(item)))
|
||||
|
||||
for i in range(1, len(item)):
|
||||
|
@ -64,38 +76,14 @@ def get_inference(dataset, seq):
|
|||
node.print()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
def get_contain_meaning(dataset, meaning):
|
||||
map = dataset.get_meaning_map()
|
||||
seqs = map.get_maps(contain=meaning)
|
||||
for seq in seqs:
|
||||
print(str(seq))
|
||||
|
||||
log_path = "log/bigger/version_1/"
|
||||
|
||||
file = get_latest_file_safe(log_path + "/checkpoints")
|
||||
checkpoint_path = log_path + "checkpoints/" + file
|
||||
conf = configuration.class_from_file(log_path + "conf.pkl")
|
||||
model = QWenLMHeadModel(conf.model_config)
|
||||
qwen = LightModule.load_from_checkpoint(checkpoint_path=checkpoint_path, model=model)
|
||||
qwen.eval()
|
||||
conf = qwen.config
|
||||
torch.manual_seed(conf.seed)
|
||||
np.random.seed(conf.seed)
|
||||
runner = ModelRunner(qwen.llm)
|
||||
|
||||
train, val = ds.InitDataset(conf)
|
||||
val = val.dataset
|
||||
# get_dataset_set_freq(train.dataset)
|
||||
md = val.meaning_dataset
|
||||
map = md.get_meaning_map()
|
||||
|
||||
# seq:844
|
||||
# seq:849
|
||||
# seq:991
|
||||
# seq:995
|
||||
meaning = 991
|
||||
meaning = 10991
|
||||
|
||||
node = map.get_nodetree(meaning)
|
||||
node.print()
|
||||
|
||||
get_inference(md, meaning)
|
||||
def dump_dk(dataset, meaning):
|
||||
|
||||
def DumpQK(query, key, causal_mask, index):
|
||||
global relation_distance
|
||||
|
@ -121,6 +109,8 @@ if __name__ == "__main__":
|
|||
# print(current_to_common)
|
||||
# print(common_to_current)
|
||||
|
||||
map = dataset.get_meaning_map()
|
||||
|
||||
relation_distance = map.get_relation_distance(meaning)
|
||||
relation_distance = torch.tensor(relation_distance)
|
||||
rd = relation_distance.unsqueeze(0)
|
||||
|
@ -139,3 +129,44 @@ if __name__ == "__main__":
|
|||
|
||||
# print(sorted_logits.detach().cpu().numpy())
|
||||
# print(sorted_indices.detach().cpu().numpy())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
log_path = "log/bigger/version_0/"
|
||||
|
||||
file = get_latest_file_safe(log_path + "/checkpoints")
|
||||
checkpoint_path = log_path + "checkpoints/" + file
|
||||
conf = configuration.class_from_file(log_path + "conf.pkl")
|
||||
model = QWenLMHeadModel(conf.model_config)
|
||||
qwen = LightModule.load_from_checkpoint(checkpoint_path=checkpoint_path, model=model)
|
||||
qwen.eval()
|
||||
conf = qwen.config
|
||||
torch.manual_seed(conf.seed)
|
||||
np.random.seed(conf.seed)
|
||||
runner = ModelRunner(qwen.llm)
|
||||
|
||||
train, val = ds.InitDataset(conf)
|
||||
val = val.dataset
|
||||
# get_dataset_set_freq(train.dataset)
|
||||
md = val.meaning_dataset
|
||||
map = md.get_meaning_map()
|
||||
|
||||
# seq:844
|
||||
# seq:849
|
||||
# seq:991
|
||||
# seq:995
|
||||
meaning = 995
|
||||
meaning = 2995
|
||||
|
||||
# node = map.get_nodetree(meaning)
|
||||
# node.print()
|
||||
|
||||
to = get_inference_single(md, meaning, 12)
|
||||
print(to)
|
||||
|
||||
get_inference(md, meaning)
|
||||
|
||||
get_contain_meaning(md, meaning)
|
||||
|
||||
dump_dk(md, meaning)
|
||||
|
|
Loading…
Reference in New Issue