diff --git a/wit/inference.py b/wit/inference.py
index e9dc575..48dfbc6 100644
--- a/wit/inference.py
+++ b/wit/inference.py
@@ -1,7 +1,7 @@
 import torch

-from model.qwen_module import QwenModule
-from model.qwen_module import ModelRunner
+from wit.model.light_module import LightModule
+from wit.model.light_module import ModelRunner
 import numpy as np

 import dataset.dataset as ds
@@ -14,7 +14,7 @@ if __name__ == "__main__":
     checkpoint_path = "log/bigger/version_8/checkpoints/epoch=49-step=246800.ckpt"
     # checkpoint_path = "log/bigger/version_11/checkpoints/epoch=25-step=128336.ckpt"

-    qwen = QwenModule.load_from_checkpoint(checkpoint_path=checkpoint_path)
+    qwen = LightModule.load_from_checkpoint(checkpoint_path=checkpoint_path)
     qwen.eval()
     conf = qwen.config
     torch.manual_seed(conf.seed)
diff --git a/wit/model/qwen_module.py b/wit/model/light_module.py
similarity index 92%
rename from wit/model/qwen_module.py
rename to wit/model/light_module.py
index 3f2177d..e766746 100644
--- a/wit/model/qwen_module.py
+++ b/wit/model/light_module.py
@@ -64,14 +64,14 @@ class LoadModule:


 class ModelRunner:
-    def __init__(self, qwen):
-        self.qwen = qwen
+    def __init__(self, model):
+        self.model = model

     @torch.no_grad()
     def ChatTokens(self, input_ids, sample=True):
-        qwen = self.qwen
-        input_ids = input_ids.to(next(qwen.parameters()).device)
-        outputs, loss = qwen.forward(input_ids)
+        model = self.model
+        input_ids = input_ids.to(next(model.parameters()).device)
+        outputs, loss = model.forward(input_ids)
         next_token_scores = outputs[:, -1, :]
         next_token_scores = self.repetition_penalty(input_ids, next_token_scores)
         if sample:
@@ -90,12 +90,12 @@ class ModelRunner:
         system: str = "You are a helpful assistant.",
         history=[],
     ):
-        qwen = self.qwen
+        model = self.model
         history = copy.deepcopy(history)
-        self.qwen.config.pad_token_id = tokenizer.eod_id
-        self.qwen.config.eos_token_id = tokenizer.eod_id
-        raw_text, context_tokens = qwen.prepareInput(tokenizer, query, query_assistant, history, system)
-        input_ids = torch.tensor([context_tokens]).to(next(qwen.parameters()).device)
+        self.model.config.pad_token_id = tokenizer.eod_id
+        self.model.config.eos_token_id = tokenizer.eod_id
+        raw_text, context_tokens = model.prepareInput(tokenizer, query, query_assistant, history, system)
+        input_ids = torch.tensor([context_tokens]).to(next(model.parameters()).device)
         self.unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
         input_length = input_ids.shape[1]
         while True:
@@ -126,7 +126,7 @@ class ModelRunner:
         return make_context(tokenizer, query, query_assistant, history=history, system=system)

     def repetition_penalty(self, input_ids, next_token_scores):
-        penalty = self.qwen.config.repetition_penalty
+        penalty = self.model.config.repetition_penalty
         score = torch.gather(next_token_scores, 1, input_ids)
         # if score < 0 then repetition penalty has to be multiplied to reduce the token probabilities
         score = torch.where(score < 0, score * penalty, score / penalty)
@@ -134,7 +134,7 @@ class ModelRunner:
         return next_token_scores

     def top_p(self, next_token_scores):
-        top_p = self.qwen.config.top_p
+        top_p = self.model.config.top_p
         filter_value = -float("Inf")
         min_tokens_to_keep = 1
         sorted_logits, sorted_indices = torch.sort(next_token_scores, descending=False)
@@ -154,8 +154,8 @@ class ModelRunner:
         return next_tokens

     def isFinish(self, next_tokens):
-        pad_token_id = self.qwen.config.pad_token_id
-        eos_token_id_tensor = torch.tensor([self.qwen.config.eos_token_id]).to(next_tokens.device)
+        pad_token_id = self.model.config.pad_token_id
+        eos_token_id_tensor = torch.tensor([self.model.config.eos_token_id]).to(next_tokens.device)
         next_tokens = next_tokens * self.unfinished_sequences + pad_token_id * (1 - self.unfinished_sequences)

         self.unfinished_sequences = self.unfinished_sequences.mul(
@@ -164,7 +164,7 @@ class ModelRunner:
         return self.unfinished_sequences.max() == 0, next_tokens[:, None]


-class QwenModule(pl.LightningModule):
+class LightModule(pl.LightningModule):
     def __init__(self, conf: TrainConfig, model):
         self.config = conf
         pretrained_model_dir = conf.pretrain_model_name
diff --git a/wit/query_block_output.py b/wit/query_block_output.py
index 1b04e45..3c6398d 100644
--- a/wit/query_block_output.py
+++ b/wit/query_block_output.py
@@ -1,7 +1,7 @@
 import torch

-from model.qwen_module import QwenModule
-from model.qwen_module import ModelRunner
+from wit.model.light_module import LightModule
+from wit.model.light_module import ModelRunner
 import numpy as np
 import math

@@ -20,7 +20,7 @@ if __name__ == "__main__":
     checkpoint_path = "log/bigger/version_3/checkpoints/epoch=46-step=231992.ckpt"
     checkpoint_path = "log/bigger/version_8/checkpoints/epoch=49-step=246800.ckpt"

-    qwen = QwenModule.load_from_checkpoint(checkpoint_path=checkpoint_path)
+    qwen = LightModule.load_from_checkpoint(checkpoint_path=checkpoint_path)
     qwen.eval()
     conf = qwen.config
     torch.manual_seed(conf.seed)
diff --git a/wit/query_meaning_freq.py b/wit/query_meaning_freq.py
index f93556e..0067157 100644
--- a/wit/query_meaning_freq.py
+++ b/wit/query_meaning_freq.py
@@ -1,7 +1,7 @@
 import pytorch_lightning as pl
 import torch

-from model.qwen_module import QwenModule
+from wit.model.light_module import LightModule
 from model.modeling_wit import ModelRunner
 from model.tokenization_qwen import QWenTokenizer
 import numpy as np
@@ -14,7 +14,7 @@ if __name__ == "__main__":

     checkpoint_path = "log/bigger/version_1/checkpoints/epoch=14-step=74040.ckpt"

-    qwen = QwenModule.load_from_checkpoint(checkpoint_path=checkpoint_path)
+    qwen = LightModule.load_from_checkpoint(checkpoint_path=checkpoint_path)
     qwen.eval()
     conf = qwen.config
     torch.manual_seed(conf.seed)
diff --git a/wit/train.py b/wit/train.py
index fea5f67..5afe410 100644
--- a/wit/train.py
+++ b/wit/train.py
@@ -1,7 +1,7 @@
 import pytorch_lightning as pl
 import torch

-from model.qwen_module import QwenModule
+from wit.model.light_module import LightModule
 from model.modeling_wit import QWenLMHeadModel
 from model.modeling_rwkv7 import RWKVLMHeadModel
 from logger import TBLogger
@@ -45,7 +45,7 @@ if __name__ == "__main__":
     np.random.seed(conf.seed)
     model = QWenLMHeadModel(conf.model_config)
     # model = RWKVLMHeadModel(conf.model_config)
-    qwen = QwenModule(conf, model)
+    qwen = LightModule(conf, model)
     train_dataloader, val_dataloader = ds.InitDataset(conf)

     # for i in range(len(train_dataloader)):
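
Post-rename usage, as a minimal smoke-test sketch: it mirrors the updated wit/inference.py hunk above (new import path, checkpoint load, seeding from the module's config). The checkpoint path is copied verbatim from the diff and is a placeholder on any other machine; passing the Lightning module itself to ModelRunner is an assumption these hunks do not settle, since they only show that the runner's constructor accepts a generic `model`.

# Minimal post-rename smoke test (sketch), mirroring wit/inference.py.
import torch

from wit.model.light_module import LightModule, ModelRunner

# Checkpoint path copied from the diff; replace with a local checkpoint.
checkpoint_path = "log/bigger/version_8/checkpoints/epoch=49-step=246800.ckpt"

qwen = LightModule.load_from_checkpoint(checkpoint_path=checkpoint_path)
qwen.eval()
torch.manual_seed(qwen.config.seed)

# Assumption: the Lightning module itself satisfies the runner's `model`
# interface (.parameters(), .forward(), .config); the hunks above only
# show that ModelRunner.__init__ takes a generic `model` argument.
runner = ModelRunner(qwen)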