Update test model.
This commit is contained in:
parent
ab65f98e39
commit
8882073978
|
@ -10,7 +10,7 @@ if __name__ == "__main__":
|
|||
|
||||
# checkpoint_path = "log/bigger/version_0/checkpoints/epoch=72-step=360328.ckpt"
|
||||
# checkpoint_path = "log/bigger/version_4/checkpoints/epoch=81-step=64288.ckpt"
|
||||
checkpoint_path = "log/bigger/version_8/checkpoints/epoch=14-step=67455.ckpt"
|
||||
checkpoint_path = "log/bigger/version_6/checkpoints/epoch=14-step=67455.ckpt"
|
||||
|
||||
qwen = LightModule.load_from_checkpoint(checkpoint_path=checkpoint_path)
|
||||
qwen.eval()
|
||||
|
@ -29,9 +29,10 @@ if __name__ == "__main__":
|
|||
# seq:849
|
||||
# seq:991
|
||||
# seq:995
|
||||
seq = 995
|
||||
|
||||
node = map.get_nodetree(995)
|
||||
item, l, rank_idx, rank_all = map.get_sequence(995)
|
||||
node = map.get_nodetree(seq)
|
||||
item, l, rank_idx, rank_all = map.get_sequence(seq)
|
||||
print("len of seq:" + str(len(item)))
|
||||
|
||||
for i in range(1, len(item)):
|
||||
|
|
|
@ -15,7 +15,7 @@ import dataset.dataset as ds
|
|||
|
||||
if __name__ == "__main__":
|
||||
|
||||
checkpoint_path = "log/bigger/version_8/checkpoints/epoch=14-step=67455.ckpt"
|
||||
checkpoint_path = "log/bigger/version_6/checkpoints/epoch=14-step=67455.ckpt"
|
||||
|
||||
qwen = LightModule.load_from_checkpoint(checkpoint_path=checkpoint_path)
|
||||
qwen.eval()
|
||||
|
|
Loading…
Reference in New Issue