Fix train define.
This commit is contained in:
		
							parent
							
								
									90e94db2c1
								
							
						
					
					
						commit
						990e27ba15
					
				| 
						 | 
				
			
			@ -1,7 +1,7 @@
 | 
			
		|||
import pytorch_lightning as pl
 | 
			
		||||
import torch
 | 
			
		||||
 | 
			
		||||
from wit.model.light_module import LightModule
 | 
			
		||||
from model.light_module import LightModule
 | 
			
		||||
from model.modeling_wit import QWenLMHeadModel
 | 
			
		||||
from model.modeling_rwkv7 import RWKVLMHeadModel
 | 
			
		||||
from logger import TBLogger
 | 
			
		||||
| 
						 | 
				
			
			@ -58,7 +58,6 @@ if __name__ == "__main__":
 | 
			
		|||
    lit_trainer = pl.Trainer(
 | 
			
		||||
        accelerator="cuda",
 | 
			
		||||
        precision=conf.precision,
 | 
			
		||||
        # logger=MLFLogger("./log/", run_name=conf.name),
 | 
			
		||||
        logger=logger,
 | 
			
		||||
        strategy=conf.strategy,
 | 
			
		||||
        max_epochs=conf.max_epochs,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue