diff --git a/wit/inference.py b/wit/inference.py index ffba9df..b932075 100644 --- a/wit/inference.py +++ b/wit/inference.py @@ -1,7 +1,7 @@ import torch -from wit.model.light_module import LightModule -from wit.model.light_module import ModelRunner +from model.light_module import LightModule +from model.light_module import ModelRunner import numpy as np import dataset.dataset as ds @@ -12,7 +12,7 @@ if __name__ == "__main__": checkpoint_path = "log/bigger/version_1/checkpoints/epoch=14-step=74040.ckpt" checkpoint_path = "log/bigger/version_3/checkpoints/epoch=46-step=231992.ckpt" checkpoint_path = "log/bigger/version_8/checkpoints/epoch=49-step=246800.ckpt" - # checkpoint_path = "log/bigger/version_11/checkpoints/epoch=25-step=128336.ckpt" + # checkpoint_path = "log/bigger/version_6/checkpoints/epoch=43-step=217184.ckpt" qwen = LightModule.load_from_checkpoint(checkpoint_path=checkpoint_path) qwen.eval() diff --git a/wit/query_block_output.py b/wit/query_block_output.py index 3c6398d..6d0c9e0 100644 --- a/wit/query_block_output.py +++ b/wit/query_block_output.py @@ -1,7 +1,7 @@ import torch -from wit.model.light_module import LightModule -from wit.model.light_module import ModelRunner +from model.light_module import LightModule +from model.light_module import ModelRunner import numpy as np import math diff --git a/wit/query_meaning_freq.py b/wit/query_meaning_freq.py index 0067157..6e46aa1 100644 --- a/wit/query_meaning_freq.py +++ b/wit/query_meaning_freq.py @@ -1,7 +1,7 @@ import pytorch_lightning as pl import torch -from wit.model.light_module import LightModule +from model.light_module import LightModule from model.modeling_wit import ModelRunner from model.tokenization_qwen import QWenTokenizer import numpy as np