Fix train define.
This commit is contained in:
parent
90e94db2c1
commit
990e27ba15
wit
|
@ -1,7 +1,7 @@
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from wit.model.light_module import LightModule
|
from model.light_module import LightModule
|
||||||
from model.modeling_wit import QWenLMHeadModel
|
from model.modeling_wit import QWenLMHeadModel
|
||||||
from model.modeling_rwkv7 import RWKVLMHeadModel
|
from model.modeling_rwkv7 import RWKVLMHeadModel
|
||||||
from logger import TBLogger
|
from logger import TBLogger
|
||||||
|
@ -58,7 +58,6 @@ if __name__ == "__main__":
|
||||||
lit_trainer = pl.Trainer(
|
lit_trainer = pl.Trainer(
|
||||||
accelerator="cuda",
|
accelerator="cuda",
|
||||||
precision=conf.precision,
|
precision=conf.precision,
|
||||||
# logger=MLFLogger("./log/", run_name=conf.name),
|
|
||||||
logger=logger,
|
logger=logger,
|
||||||
strategy=conf.strategy,
|
strategy=conf.strategy,
|
||||||
max_epochs=conf.max_epochs,
|
max_epochs=conf.max_epochs,
|
||||||
|
|
Loading…
Reference in New Issue