Rename QwenModule to LightModule.

parent 1efda9fe25
commit 90e94db2c1
@@ -1,7 +1,7 @@
 import torch

-from model.qwen_module import QwenModule
-from model.qwen_module import ModelRunner
+from wit.model.light_module import LightModule
+from wit.model.light_module import ModelRunner
 import numpy as np

 import dataset.dataset as ds
@@ -14,7 +14,7 @@ if __name__ == "__main__":
     checkpoint_path = "log/bigger/version_8/checkpoints/epoch=49-step=246800.ckpt"
     # checkpoint_path = "log/bigger/version_11/checkpoints/epoch=25-step=128336.ckpt"

-    qwen = QwenModule.load_from_checkpoint(checkpoint_path=checkpoint_path)
+    qwen = LightModule.load_from_checkpoint(checkpoint_path=checkpoint_path)
     qwen.eval()
     conf = qwen.config
     torch.manual_seed(conf.seed)
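For callers, the rename is mechanical. A minimal migration sketch, with the class names and module path taken from the hunks above (the checkpoint path is just one of those appearing in this commit):

import torch
from wit.model.light_module import LightModule, ModelRunner

# Load the Lightning checkpoint through the renamed class; behavior is
# unchanged, only the import path and class name differ.
qwen = LightModule.load_from_checkpoint(
    checkpoint_path="log/bigger/version_8/checkpoints/epoch=49-step=246800.ckpt"
)
qwen.eval()
runner = ModelRunner(qwen)  # stored as self.model after this commit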
@@ -64,14 +64,14 @@ class LoadModule:


 class ModelRunner:
-    def __init__(self, qwen):
-        self.qwen = qwen
+    def __init__(self, model):
+        self.model = model

     @torch.no_grad()
     def ChatTokens(self, input_ids, sample=True):
-        qwen = self.qwen
-        input_ids = input_ids.to(next(qwen.parameters()).device)
-        outputs, loss = qwen.forward(input_ids)
+        model = self.model
+        input_ids = input_ids.to(next(model.parameters()).device)
+        outputs, loss = model.forward(input_ids)
         next_token_scores = outputs[:, -1, :]
         next_token_scores = self.repetition_penalty(input_ids, next_token_scores)
         if sample:
@@ -90,12 +90,12 @@ class ModelRunner:
         system: str = "You are a helpful assistant.",
         history=[],
     ):
-        qwen = self.qwen
+        model = self.model
         history = copy.deepcopy(history)
-        self.qwen.config.pad_token_id = tokenizer.eod_id
-        self.qwen.config.eos_token_id = tokenizer.eod_id
-        raw_text, context_tokens = qwen.prepareInput(tokenizer, query, query_assistant, history, system)
-        input_ids = torch.tensor([context_tokens]).to(next(qwen.parameters()).device)
+        self.model.config.pad_token_id = tokenizer.eod_id
+        self.model.config.eos_token_id = tokenizer.eod_id
+        raw_text, context_tokens = model.prepareInput(tokenizer, query, query_assistant, history, system)
+        input_ids = torch.tensor([context_tokens]).to(next(model.parameters()).device)
         self.unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
         input_length = input_ids.shape[1]
         while True:
@@ -126,7 +126,7 @@ class ModelRunner:
         return make_context(tokenizer, query, query_assistant, history=history, system=system)

     def repetition_penalty(self, input_ids, next_token_scores):
-        penalty = self.qwen.config.repetition_penalty
+        penalty = self.model.config.repetition_penalty
         score = torch.gather(next_token_scores, 1, input_ids)
         # if score < 0 then repetition penalty has to be multiplied to reduce the token probabilities
         score = torch.where(score < 0, score * penalty, score / penalty)
@@ -134,7 +134,7 @@ class ModelRunner:
         return next_token_scores

     def top_p(self, next_token_scores):
-        top_p = self.qwen.config.top_p
+        top_p = self.model.config.top_p
         filter_value = -float("Inf")
         min_tokens_to_keep = 1
         sorted_logits, sorted_indices = torch.sort(next_token_scores, descending=False)
@@ -154,8 +154,8 @@ class ModelRunner:
         return next_tokens

     def isFinish(self, next_tokens):
-        pad_token_id = self.qwen.config.pad_token_id
-        eos_token_id_tensor = torch.tensor([self.qwen.config.eos_token_id]).to(next_tokens.device)
+        pad_token_id = self.model.config.pad_token_id
+        eos_token_id_tensor = torch.tensor([self.model.config.eos_token_id]).to(next_tokens.device)

         next_tokens = next_tokens * self.unfinished_sequences + pad_token_id * (1 - self.unfinished_sequences)
         self.unfinished_sequences = self.unfinished_sequences.mul(
@@ -164,7 +164,7 @@ class ModelRunner:
         return self.unfinished_sequences.max() == 0, next_tokens[:, None]


-class QwenModule(pl.LightningModule):
+class LightModule(pl.LightningModule):
     def __init__(self, conf: TrainConfig, model):
         self.config = conf
         pretrained_model_dir = conf.pretrain_model_name
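The ModelRunner helpers above change only which attribute the config is read from; the sampling logic itself is untouched. For reference, a self-contained sketch of the two filters: the scatter-back in repetition_penalty and the body of top_p below the sort are not visible in this diff, so those parts follow the common HuggingFace-style implementation, and the penalty/top_p defaults are illustrative stand-ins for the values read from model.config.

import torch

# Repetition penalty (rule visible above): tokens already in input_ids are
# discouraged from being sampled again; negative logits are multiplied by the
# penalty (pushed further down), positive ones divided (pulled toward zero),
# then written back. The write-back step is assumed, not shown in the diff.
def repetition_penalty(input_ids, next_token_scores, penalty=1.5):
    score = torch.gather(next_token_scores, 1, input_ids)
    score = torch.where(score < 0, score * penalty, score / penalty)
    return next_token_scores.scatter(1, input_ids, score)

# Top-p / nucleus filter: the ascending sort, -inf filter value, and
# min_tokens_to_keep seen above match the usual implementation sketched here.
def top_p(next_token_scores, top_p=0.9, min_tokens_to_keep=1):
    filter_value = -float("Inf")
    sorted_logits, sorted_indices = torch.sort(next_token_scores, descending=False)
    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
    # Drop the low-probability tail whose cumulative mass stays below 1 - top_p,
    # but always keep at least min_tokens_to_keep tokens.
    sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
    sorted_indices_to_remove[..., -min_tokens_to_keep:] = False
    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
    return next_token_scores.masked_fill(indices_to_remove, filter_value)

# Worked check for the penalty: tokens 0 and 1 were already generated.
logits = torch.tensor([[2.0, -1.0, 0.5]])
print(repetition_penalty(torch.tensor([[0, 1]]), logits, penalty=2.0))
# -> tensor([[ 1.0000, -2.0000,  0.5000]])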
@@ -1,7 +1,7 @@
 import torch

-from model.qwen_module import QwenModule
-from model.qwen_module import ModelRunner
+from wit.model.light_module import LightModule
+from wit.model.light_module import ModelRunner
 import numpy as np

 import math
@@ -20,7 +20,7 @@ if __name__ == "__main__":
     checkpoint_path = "log/bigger/version_3/checkpoints/epoch=46-step=231992.ckpt"
     checkpoint_path = "log/bigger/version_8/checkpoints/epoch=49-step=246800.ckpt"

-    qwen = QwenModule.load_from_checkpoint(checkpoint_path=checkpoint_path)
+    qwen = LightModule.load_from_checkpoint(checkpoint_path=checkpoint_path)
     qwen.eval()
     conf = qwen.config
     torch.manual_seed(conf.seed)
@@ -1,7 +1,7 @@
 import pytorch_lightning as pl
 import torch

-from model.qwen_module import QwenModule
+from wit.model.light_module import LightModule
 from model.modeling_wit import ModelRunner
 from model.tokenization_qwen import QWenTokenizer
 import numpy as np
@@ -14,7 +14,7 @@ if __name__ == "__main__":

     checkpoint_path = "log/bigger/version_1/checkpoints/epoch=14-step=74040.ckpt"

-    qwen = QwenModule.load_from_checkpoint(checkpoint_path=checkpoint_path)
+    qwen = LightModule.load_from_checkpoint(checkpoint_path=checkpoint_path)
     qwen.eval()
     conf = qwen.config
     torch.manual_seed(conf.seed)
@@ -1,7 +1,7 @@
 import pytorch_lightning as pl
 import torch

-from model.qwen_module import QwenModule
+from wit.model.light_module import LightModule
 from model.modeling_wit import QWenLMHeadModel
 from model.modeling_rwkv7 import RWKVLMHeadModel
 from logger import TBLogger
@@ -45,7 +45,7 @@ if __name__ == "__main__":
     np.random.seed(conf.seed)
     model = QWenLMHeadModel(conf.model_config)
     # model = RWKVLMHeadModel(conf.model_config)
-    qwen = QwenModule(conf, model)
+    qwen = LightModule(conf, model)

     train_dataloader, val_dataloader = ds.InitDataset(conf)
     # for i in range(len(train_dataloader)):