Fix train define.
This commit is contained in:
parent
90e94db2c1
commit
990e27ba15
|
@ -1,7 +1,7 @@
|
|||
import pytorch_lightning as pl
|
||||
import torch
|
||||
|
||||
from wit.model.light_module import LightModule
|
||||
from model.light_module import LightModule
|
||||
from model.modeling_wit import QWenLMHeadModel
|
||||
from model.modeling_rwkv7 import RWKVLMHeadModel
|
||||
from logger import TBLogger
|
||||
|
@ -58,7 +58,6 @@ if __name__ == "__main__":
|
|||
lit_trainer = pl.Trainer(
|
||||
accelerator="cuda",
|
||||
precision=conf.precision,
|
||||
# logger=MLFLogger("./log/", run_name=conf.name),
|
||||
logger=logger,
|
||||
strategy=conf.strategy,
|
||||
max_epochs=conf.max_epochs,
|
||||
|
|
Loading…
Reference in New Issue