Rename QwenModule to lightmodule.
This commit is contained in:
parent
1efda9fe25
commit
90e94db2c1
|
@ -1,7 +1,7 @@
|
|||
import torch
|
||||
|
||||
from model.qwen_module import QwenModule
|
||||
from model.qwen_module import ModelRunner
|
||||
from wit.model.light_module import LightModule
|
||||
from wit.model.light_module import ModelRunner
|
||||
import numpy as np
|
||||
|
||||
import dataset.dataset as ds
|
||||
|
@ -14,7 +14,7 @@ if __name__ == "__main__":
|
|||
checkpoint_path = "log/bigger/version_8/checkpoints/epoch=49-step=246800.ckpt"
|
||||
# checkpoint_path = "log/bigger/version_11/checkpoints/epoch=25-step=128336.ckpt"
|
||||
|
||||
qwen = QwenModule.load_from_checkpoint(checkpoint_path=checkpoint_path)
|
||||
qwen = LightModule.load_from_checkpoint(checkpoint_path=checkpoint_path)
|
||||
qwen.eval()
|
||||
conf = qwen.config
|
||||
torch.manual_seed(conf.seed)
|
||||
|
|
|
@ -64,14 +64,14 @@ class LoadModule:
|
|||
|
||||
|
||||
class ModelRunner:
|
||||
def __init__(self, qwen):
|
||||
self.qwen = qwen
|
||||
def __init__(self, model):
|
||||
self.model = model
|
||||
|
||||
@torch.no_grad()
|
||||
def ChatTokens(self, input_ids, sample=True):
|
||||
qwen = self.qwen
|
||||
input_ids = input_ids.to(next(qwen.parameters()).device)
|
||||
outputs, loss = qwen.forward(input_ids)
|
||||
model = self.model
|
||||
input_ids = input_ids.to(next(model.parameters()).device)
|
||||
outputs, loss = model.forward(input_ids)
|
||||
next_token_scores = outputs[:, -1, :]
|
||||
next_token_scores = self.repetition_penalty(input_ids, next_token_scores)
|
||||
if sample:
|
||||
|
@ -90,12 +90,12 @@ class ModelRunner:
|
|||
system: str = "You are a helpful assistant.",
|
||||
history=[],
|
||||
):
|
||||
qwen = self.qwen
|
||||
model = self.model
|
||||
history = copy.deepcopy(history)
|
||||
self.qwen.config.pad_token_id = tokenizer.eod_id
|
||||
self.qwen.config.eos_token_id = tokenizer.eod_id
|
||||
raw_text, context_tokens = qwen.prepareInput(tokenizer, query, query_assistant, history, system)
|
||||
input_ids = torch.tensor([context_tokens]).to(next(qwen.parameters()).device)
|
||||
self.model.config.pad_token_id = tokenizer.eod_id
|
||||
self.model.config.eos_token_id = tokenizer.eod_id
|
||||
raw_text, context_tokens = model.prepareInput(tokenizer, query, query_assistant, history, system)
|
||||
input_ids = torch.tensor([context_tokens]).to(next(model.parameters()).device)
|
||||
self.unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
|
||||
input_length = input_ids.shape[1]
|
||||
while True:
|
||||
|
@ -126,7 +126,7 @@ class ModelRunner:
|
|||
return make_context(tokenizer, query, query_assistant, history=history, system=system)
|
||||
|
||||
def repetition_penalty(self, input_ids, next_token_scores):
|
||||
penalty = self.qwen.config.repetition_penalty
|
||||
penalty = self.model.config.repetition_penalty
|
||||
score = torch.gather(next_token_scores, 1, input_ids)
|
||||
# if score < 0 then repetition penalty has to be multiplied to reduce the token probabilities
|
||||
score = torch.where(score < 0, score * penalty, score / penalty)
|
||||
|
@ -134,7 +134,7 @@ class ModelRunner:
|
|||
return next_token_scores
|
||||
|
||||
def top_p(self, next_token_scores):
|
||||
top_p = self.qwen.config.top_p
|
||||
top_p = self.model.config.top_p
|
||||
filter_value = -float("Inf")
|
||||
min_tokens_to_keep = 1
|
||||
sorted_logits, sorted_indices = torch.sort(next_token_scores, descending=False)
|
||||
|
@ -154,8 +154,8 @@ class ModelRunner:
|
|||
return next_tokens
|
||||
|
||||
def isFinish(self, next_tokens):
|
||||
pad_token_id = self.qwen.config.pad_token_id
|
||||
eos_token_id_tensor = torch.tensor([self.qwen.config.eos_token_id]).to(next_tokens.device)
|
||||
pad_token_id = self.model.config.pad_token_id
|
||||
eos_token_id_tensor = torch.tensor([self.model.config.eos_token_id]).to(next_tokens.device)
|
||||
|
||||
next_tokens = next_tokens * self.unfinished_sequences + pad_token_id * (1 - self.unfinished_sequences)
|
||||
self.unfinished_sequences = self.unfinished_sequences.mul(
|
||||
|
@ -164,7 +164,7 @@ class ModelRunner:
|
|||
return self.unfinished_sequences.max() == 0, next_tokens[:, None]
|
||||
|
||||
|
||||
class QwenModule(pl.LightningModule):
|
||||
class LightModule(pl.LightningModule):
|
||||
def __init__(self, conf: TrainConfig, model):
|
||||
self.config = conf
|
||||
pretrained_model_dir = conf.pretrain_model_name
|
|
@ -1,7 +1,7 @@
|
|||
import torch
|
||||
|
||||
from model.qwen_module import QwenModule
|
||||
from model.qwen_module import ModelRunner
|
||||
from wit.model.light_module import LightModule
|
||||
from wit.model.light_module import ModelRunner
|
||||
import numpy as np
|
||||
|
||||
import math
|
||||
|
@ -20,7 +20,7 @@ if __name__ == "__main__":
|
|||
checkpoint_path = "log/bigger/version_3/checkpoints/epoch=46-step=231992.ckpt"
|
||||
checkpoint_path = "log/bigger/version_8/checkpoints/epoch=49-step=246800.ckpt"
|
||||
|
||||
qwen = QwenModule.load_from_checkpoint(checkpoint_path=checkpoint_path)
|
||||
qwen = LightModule.load_from_checkpoint(checkpoint_path=checkpoint_path)
|
||||
qwen.eval()
|
||||
conf = qwen.config
|
||||
torch.manual_seed(conf.seed)
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import pytorch_lightning as pl
|
||||
import torch
|
||||
|
||||
from model.qwen_module import QwenModule
|
||||
from wit.model.light_module import LightModule
|
||||
from model.modeling_wit import ModelRunner
|
||||
from model.tokenization_qwen import QWenTokenizer
|
||||
import numpy as np
|
||||
|
@ -14,7 +14,7 @@ if __name__ == "__main__":
|
|||
|
||||
checkpoint_path = "log/bigger/version_1/checkpoints/epoch=14-step=74040.ckpt"
|
||||
|
||||
qwen = QwenModule.load_from_checkpoint(checkpoint_path=checkpoint_path)
|
||||
qwen = LightModule.load_from_checkpoint(checkpoint_path=checkpoint_path)
|
||||
qwen.eval()
|
||||
conf = qwen.config
|
||||
torch.manual_seed(conf.seed)
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import pytorch_lightning as pl
|
||||
import torch
|
||||
|
||||
from model.qwen_module import QwenModule
|
||||
from wit.model.light_module import LightModule
|
||||
from model.modeling_wit import QWenLMHeadModel
|
||||
from model.modeling_rwkv7 import RWKVLMHeadModel
|
||||
from logger import TBLogger
|
||||
|
@ -45,7 +45,7 @@ if __name__ == "__main__":
|
|||
np.random.seed(conf.seed)
|
||||
model = QWenLMHeadModel(conf.model_config)
|
||||
# model = RWKVLMHeadModel(conf.model_config)
|
||||
qwen = QwenModule(conf, model)
|
||||
qwen = LightModule(conf, model)
|
||||
|
||||
train_dataloader, val_dataloader = ds.InitDataset(conf)
|
||||
# for i in range(len(train_dataloader)):
|
||||
|
|
Loading…
Reference in New Issue