Fix train define.

This commit is contained in:
Colin 2025-03-12 20:02:02 +08:00
parent 90e94db2c1
commit 990e27ba15
1 changed files with 1 additions and 2 deletions

View File

@ -1,7 +1,7 @@
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from wit.model.light_module import LightModule from model.light_module import LightModule
from model.modeling_wit import QWenLMHeadModel from model.modeling_wit import QWenLMHeadModel
from model.modeling_rwkv7 import RWKVLMHeadModel from model.modeling_rwkv7 import RWKVLMHeadModel
from logger import TBLogger from logger import TBLogger
@ -58,7 +58,6 @@ if __name__ == "__main__":
lit_trainer = pl.Trainer( lit_trainer = pl.Trainer(
accelerator="cuda", accelerator="cuda",
precision=conf.precision, precision=conf.precision,
# logger=MLFLogger("./log/", run_name=conf.name),
logger=logger, logger=logger,
strategy=conf.strategy, strategy=conf.strategy,
max_epochs=conf.max_epochs, max_epochs=conf.max_epochs,