Rename QwenModule to LightModule.
parent 1efda9fe25
commit 90e94db2c1
@@ -1,7 +1,7 @@
 import torch
 
-from model.qwen_module import QwenModule
-from model.qwen_module import ModelRunner
+from wit.model.light_module import LightModule
+from wit.model.light_module import ModelRunner
 import numpy as np
 
 import dataset.dataset as ds
@@ -14,7 +14,7 @@ if __name__ == "__main__":
     checkpoint_path = "log/bigger/version_8/checkpoints/epoch=49-step=246800.ckpt"
     # checkpoint_path = "log/bigger/version_11/checkpoints/epoch=25-step=128336.ckpt"
 
-    qwen = QwenModule.load_from_checkpoint(checkpoint_path=checkpoint_path)
+    qwen = LightModule.load_from_checkpoint(checkpoint_path=checkpoint_path)
     qwen.eval()
     conf = qwen.config
     torch.manual_seed(conf.seed)
@@ -64,14 +64,14 @@ class LoadModule:
 
 
 class ModelRunner:
-    def __init__(self, qwen):
-        self.qwen = qwen
+    def __init__(self, model):
+        self.model = model
 
     @torch.no_grad()
     def ChatTokens(self, input_ids, sample=True):
-        qwen = self.qwen
-        input_ids = input_ids.to(next(qwen.parameters()).device)
-        outputs, loss = qwen.forward(input_ids)
+        model = self.model
+        input_ids = input_ids.to(next(model.parameters()).device)
+        outputs, loss = model.forward(input_ids)
         next_token_scores = outputs[:, -1, :]
         next_token_scores = self.repetition_penalty(input_ids, next_token_scores)
         if sample:
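The hunk above cuts off at `if sample:`. For orientation, the branch that typically follows in a decode loop of this shape is sketched below; the body is assumed from the common sample-vs-argmax pattern, not read out of the diff.

import torch

# Hypothetical sketch of the branch truncated above: either sample from the
# softmax of the (already penalty-adjusted) scores, or take the argmax.
def pick_next_token(next_token_scores, sample=True):
    if sample:
        probs = torch.softmax(next_token_scores, dim=-1)
        return torch.multinomial(probs, num_samples=1).squeeze(1)
    return torch.argmax(next_token_scores, dim=-1)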
@@ -90,12 +90,12 @@ class ModelRunner:
         system: str = "You are a helpful assistant.",
         history=[],
     ):
-        qwen = self.qwen
+        model = self.model
         history = copy.deepcopy(history)
-        self.qwen.config.pad_token_id = tokenizer.eod_id
-        self.qwen.config.eos_token_id = tokenizer.eod_id
-        raw_text, context_tokens = qwen.prepareInput(tokenizer, query, query_assistant, history, system)
-        input_ids = torch.tensor([context_tokens]).to(next(qwen.parameters()).device)
+        self.model.config.pad_token_id = tokenizer.eod_id
+        self.model.config.eos_token_id = tokenizer.eod_id
+        raw_text, context_tokens = model.prepareInput(tokenizer, query, query_assistant, history, system)
+        input_ids = torch.tensor([context_tokens]).to(next(model.parameters()).device)
         self.unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
         input_length = input_ids.shape[1]
         while True:
@@ -126,7 +126,7 @@ class ModelRunner:
         return make_context(tokenizer, query, query_assistant, history=history, system=system)
 
     def repetition_penalty(self, input_ids, next_token_scores):
-        penalty = self.qwen.config.repetition_penalty
+        penalty = self.model.config.repetition_penalty
         score = torch.gather(next_token_scores, 1, input_ids)
         # if score < 0 then repetition penalty has to be multiplied to reduce the token probabilities
         score = torch.where(score < 0, score * penalty, score / penalty)
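The `repetition_penalty` hunk only renames `self.qwen`, but its context shows the standard gather/where penalty. A minimal self-contained sketch of that scheme; the shapes, penalty value, and final scatter step are illustrative assumptions, not repo code.

import torch

def apply_repetition_penalty(input_ids, next_token_scores, penalty=1.1):
    # Gather the scores of tokens already present in the context.
    score = torch.gather(next_token_scores, 1, input_ids)
    # Negative scores are multiplied by the penalty, positive ones divided,
    # so repeated tokens always move toward lower probability.
    score = torch.where(score < 0, score * penalty, score / penalty)
    # Write the penalized scores back into the full vocabulary row (assumed step).
    return next_token_scores.scatter(1, input_ids, score)

logits = torch.randn(1, 151936)          # (batch, vocab); sizes are made up
context = torch.tensor([[11, 42, 42]])   # token ids generated so far
logits = apply_repetition_penalty(context, logits)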
@@ -134,7 +134,7 @@ class ModelRunner:
         return next_token_scores
 
     def top_p(self, next_token_scores):
-        top_p = self.qwen.config.top_p
+        top_p = self.model.config.top_p
         filter_value = -float("Inf")
         min_tokens_to_keep = 1
         sorted_logits, sorted_indices = torch.sort(next_token_scores, descending=False)
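`top_p` is likewise a rename-only hunk, but its visible context (ascending sort, `filter_value`, `min_tokens_to_keep`) matches the usual nucleus-filtering recipe. A sketch of that recipe, with the masking steps assumed from the standard formulation rather than taken from the diff:

import torch

def top_p_filter(next_token_scores, top_p=0.9, min_tokens_to_keep=1):
    filter_value = -float("Inf")
    # Ascending sort: the least likely tokens come first.
    sorted_logits, sorted_indices = torch.sort(next_token_scores, descending=False)
    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
    # Remove tokens whose cumulative mass lies outside the top_p nucleus.
    sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
    # Never mask the last min_tokens_to_keep (most likely) entries.
    sorted_indices_to_remove[..., -min_tokens_to_keep:] = 0
    # Map the mask back from sorted order to vocabulary order.
    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
    return next_token_scores.masked_fill(indices_to_remove, filter_value)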
@@ -154,8 +154,8 @@ class ModelRunner:
         return next_tokens
 
     def isFinish(self, next_tokens):
-        pad_token_id = self.qwen.config.pad_token_id
-        eos_token_id_tensor = torch.tensor([self.qwen.config.eos_token_id]).to(next_tokens.device)
+        pad_token_id = self.model.config.pad_token_id
+        eos_token_id_tensor = torch.tensor([self.model.config.eos_token_id]).to(next_tokens.device)
 
         next_tokens = next_tokens * self.unfinished_sequences + pad_token_id * (1 - self.unfinished_sequences)
         self.unfinished_sequences = self.unfinished_sequences.mul(
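The `isFinish` hunk shows the `unfinished_sequences` bookkeeping: finished rows keep emitting the pad token, and generation stops once every row has produced an EOS. A condensed sketch, assuming a single EOS id (the real code builds an `eos_token_id_tensor` and could handle several):

import torch

def is_finish(next_tokens, unfinished_sequences, pad_token_id, eos_token_id):
    # Rows already finished are frozen to pad_token_id.
    next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
    # A row stays unfinished only while its newest token is not EOS.
    unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long())
    # Done when no row is unfinished; also return the column of new tokens.
    return unfinished_sequences.max() == 0, next_tokens[:, None], unfinished_sequences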
@@ -164,7 +164,7 @@ class ModelRunner:
         return self.unfinished_sequences.max() == 0, next_tokens[:, None]
 
 
-class QwenModule(pl.LightningModule):
+class LightModule(pl.LightningModule):
     def __init__(self, conf: TrainConfig, model):
         self.config = conf
         pretrained_model_dir = conf.pretrain_model_name
@@ -1,7 +1,7 @@
 import torch
 
-from model.qwen_module import QwenModule
-from model.qwen_module import ModelRunner
+from wit.model.light_module import LightModule
+from wit.model.light_module import ModelRunner
 import numpy as np
 
 import math
@@ -20,7 +20,7 @@ if __name__ == "__main__":
     checkpoint_path = "log/bigger/version_3/checkpoints/epoch=46-step=231992.ckpt"
     checkpoint_path = "log/bigger/version_8/checkpoints/epoch=49-step=246800.ckpt"
 
-    qwen = QwenModule.load_from_checkpoint(checkpoint_path=checkpoint_path)
+    qwen = LightModule.load_from_checkpoint(checkpoint_path=checkpoint_path)
     qwen.eval()
     conf = qwen.config
     torch.manual_seed(conf.seed)
@@ -1,7 +1,7 @@
 import pytorch_lightning as pl
 import torch
 
-from model.qwen_module import QwenModule
+from wit.model.light_module import LightModule
 from model.modeling_wit import ModelRunner
 from model.tokenization_qwen import QWenTokenizer
 import numpy as np
@@ -14,7 +14,7 @@ if __name__ == "__main__":
 
     checkpoint_path = "log/bigger/version_1/checkpoints/epoch=14-step=74040.ckpt"
 
-    qwen = QwenModule.load_from_checkpoint(checkpoint_path=checkpoint_path)
+    qwen = LightModule.load_from_checkpoint(checkpoint_path=checkpoint_path)
     qwen.eval()
     conf = qwen.config
     torch.manual_seed(conf.seed)
@@ -1,7 +1,7 @@
 import pytorch_lightning as pl
 import torch
 
-from model.qwen_module import QwenModule
+from wit.model.light_module import LightModule
 from model.modeling_wit import QWenLMHeadModel
 from model.modeling_rwkv7 import RWKVLMHeadModel
 from logger import TBLogger
@@ -45,7 +45,7 @@ if __name__ == "__main__":
     np.random.seed(conf.seed)
     model = QWenLMHeadModel(conf.model_config)
     # model = RWKVLMHeadModel(conf.model_config)
-    qwen = QwenModule(conf, model)
+    qwen = LightModule(conf, model)
 
     train_dataloader, val_dataloader = ds.InitDataset(conf)
     # for i in range(len(train_dataloader)):
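Taken together, the call sites after this commit follow one pattern, condensed below from the hunks above; the checkpoint path is one of the paths shown in the diff, and nothing beyond the diff is assumed:

import torch
from wit.model.light_module import LightModule, ModelRunner

# Load a trained checkpoint through the renamed LightningModule and
# reuse its stored config for reproducible decoding.
checkpoint_path = "log/bigger/version_8/checkpoints/epoch=49-step=246800.ckpt"
qwen = LightModule.load_from_checkpoint(checkpoint_path=checkpoint_path)
qwen.eval()
conf = qwen.config
torch.manual_seed(conf.seed)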