diff --git a/wit/train.py b/wit/train.py index 5afe410..0b65ca9 100644 --- a/wit/train.py +++ b/wit/train.py @@ -1,7 +1,7 @@ import pytorch_lightning as pl import torch -from wit.model.light_module import LightModule +from model.light_module import LightModule from model.modeling_wit import QWenLMHeadModel from model.modeling_rwkv7 import RWKVLMHeadModel from logger import TBLogger @@ -58,7 +58,6 @@ if __name__ == "__main__": lit_trainer = pl.Trainer( accelerator="cuda", precision=conf.precision, - # logger=MLFLogger("./log/", run_name=conf.name), logger=logger, strategy=conf.strategy, max_epochs=conf.max_epochs,