Fix train define.
This commit is contained in:
		
							parent
							
								
									90e94db2c1
								
							
						
					
					
						commit
						990e27ba15
					
				| 
						 | 
					@ -1,7 +1,7 @@
 | 
				
			||||||
import pytorch_lightning as pl
 | 
					import pytorch_lightning as pl
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from wit.model.light_module import LightModule
 | 
					from model.light_module import LightModule
 | 
				
			||||||
from model.modeling_wit import QWenLMHeadModel
 | 
					from model.modeling_wit import QWenLMHeadModel
 | 
				
			||||||
from model.modeling_rwkv7 import RWKVLMHeadModel
 | 
					from model.modeling_rwkv7 import RWKVLMHeadModel
 | 
				
			||||||
from logger import TBLogger
 | 
					from logger import TBLogger
 | 
				
			||||||
| 
						 | 
					@ -58,7 +58,6 @@ if __name__ == "__main__":
 | 
				
			||||||
    lit_trainer = pl.Trainer(
 | 
					    lit_trainer = pl.Trainer(
 | 
				
			||||||
        accelerator="cuda",
 | 
					        accelerator="cuda",
 | 
				
			||||||
        precision=conf.precision,
 | 
					        precision=conf.precision,
 | 
				
			||||||
        # logger=MLFLogger("./log/", run_name=conf.name),
 | 
					 | 
				
			||||||
        logger=logger,
 | 
					        logger=logger,
 | 
				
			||||||
        strategy=conf.strategy,
 | 
					        strategy=conf.strategy,
 | 
				
			||||||
        max_epochs=conf.max_epochs,
 | 
					        max_epochs=conf.max_epochs,
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue