Rename mask level and index.

Colin 2025-02-21 15:33:37 +08:00
parent bca06af2dc
commit 7cf31a1f78
7 changed files with 37 additions and 93 deletions

View File

@@ -42,8 +42,8 @@ class MeaningDatasetConfig:
         self.level = 5
         self.dataset_level = 3
         self.min_subitem = 2
-        self.mask_level = None
-        self.mask_idx = None
+        self.val_mask_level = None
+        self.val_mask_idx = None

 class DatasetConfig:
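
For orientation, a minimal sketch of how the renamed config fields would be filled in; the concrete list values here are illustrative, not taken from the commit:

    conf = MeaningDatasetConfig()
    conf.val_mask_level = [0, 1, 2]  # tree levels the mask constrains
    conf.val_mask_idx = [0, 0, -1]   # required index per level; -1 means the last child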

View File

@@ -42,14 +42,14 @@ def InitDataset(config):
     if os.path.exists(trainfile) and os.path.exists(valfile):
         print(f"INFO: Load dataset from {trainfile}")
         train_dataset = torch.load(trainfile, weights_only=False)
-        train_dataset.set_mask(c.mask_level, c.mask_idx)
+        train_dataset.set_mask(c.val_mask_level, c.val_mask_idx)
         print(f"INFO: Load dataset from {valfile}")
         val_dataset = torch.load(valfile, weights_only=False)
-        val_dataset.set_mask(c.mask_level, c.mask_idx)
+        val_dataset.set_mask(c.val_mask_level, c.val_mask_idx)
         print(f"INFO: Load dataset end")
     else:
         raw_dataset = MeaningDataset(start, start + size, vocab, None, c.level_ratio, c.min_subitem)
-        raw_dataset.set_mask(c.mask_level, c.mask_idx)
+        raw_dataset.set_mask(c.val_mask_level, c.val_mask_idx)
         train_dataset, val_dataset = raw_dataset.split(0.9)
         torch.save(train_dataset, trainfile)
         torch.save(val_dataset, valfile)

View File

@@ -268,8 +268,8 @@ class MeaningDataset(Dataset):
         map = MeaningMap(end, vocab_size, max_subitem, min_subitem, use_cache=use_cache)
         np.random.seed(seed)
         print("Build MeaningDataset from MeaningMap.")
-        self.mask_level = None
-        self.mask_idx = None
+        self.val_mask_level = None
+        self.val_mask_idx = None
         self.tree = []
         self.seq = []
         self.level = []
@@ -334,13 +334,13 @@ class MeaningDataset(Dataset):
         return len(self.seq)

     def set_mask(self, level=None, idx=None):
-        if self.mask_level is not None and self.mask_idx is not None:
-            assert len(self.mask_level) > 0, "len must > 0"
-            assert len(self.mask_level) == len(self.mask_idx), "mask level and mask index must be same length"
-            assert isinstance(self.mask_level, list), "mask level must be list"
-            assert isinstance(self.mask_idx, list), "mask index must be list"
-        self.mask_level = level
-        self.mask_idx = idx
+        if self.val_mask_level is not None and self.val_mask_idx is not None:
+            assert len(self.val_mask_level) > 0, "len must > 0"
+            assert len(self.val_mask_level) == len(self.val_mask_idx), "mask level and mask index must be same length"
+            assert isinstance(self.val_mask_level, list), "mask level must be list"
+            assert isinstance(self.val_mask_idx, list), "mask index must be list"
+        self.val_mask_level = level
+        self.val_mask_idx = idx

     def __getitem__(self, idx):
         return self.get_batch([idx])
@@ -377,8 +377,8 @@ class MeaningDataset(Dataset):
         new.rank_idx = new.rank_idx[start:end]
         new.rank_all = new.rank_all[start:end]
         new.seq_meaning = new.seq_meaning[start:end]
-        new.mask_level = self.mask_level
-        new.mask_idx = self.mask_idx
+        new.val_mask_level = self.val_mask_level
+        new.val_mask_idx = self.val_mask_idx
         return new

     def split(self, ratio):
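
Because the copy helper above now carries val_mask_level and val_mask_idx over to the new object, a mask set before split() survives on both halves. A sketch of that behaviour, assuming split() builds its halves through this helper, as the surrounding code suggests:

    raw_dataset.set_mask([0, 2], [0, -1])
    train_dataset, val_dataset = raw_dataset.split(0.9)
    assert val_dataset.val_mask_level == [0, 2]  # inherited from raw_dataset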
@@ -400,13 +400,15 @@ class MeaningDataset(Dataset):
         return rank_idx == (rank_all + index if index < 0 else index)

     def get_seq_mask_tensor(self, idx_list):
-        if self.mask_level is not None and self.mask_idx is not None:
+        if self.val_mask_level is not None and self.val_mask_idx is not None:
             mask = torch.tensor(
-                np.stack([self.get_seq_mask(idx, self.mask_level[0], self.mask_idx[0]) for idx in idx_list], axis=0)
+                np.stack(
+                    [self.get_seq_mask(idx, self.val_mask_level[0], self.val_mask_idx[0]) for idx in idx_list], axis=0
+                )
             )
-            for i, l in enumerate(self.mask_level[1:]):
+            for i, l in enumerate(self.val_mask_level[1:]):
                 mask = mask & torch.tensor(
-                    np.stack([self.get_seq_mask(idx, l, self.mask_idx[i + 1]) for idx in idx_list], axis=0)
+                    np.stack([self.get_seq_mask(idx, l, self.val_mask_idx[i + 1]) for idx in idx_list], axis=0)
                 )
             return mask
         else:
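
The reflowed get_seq_mask_tensor keeps its old behaviour: one boolean mask per (level, idx) pair, all pairs combined with a logical AND. A standalone sketch of that combination step, with get_seq_mask passed in as a given since only its boolean output matters here:

    import numpy as np
    import torch

    def combine_masks(get_seq_mask, idx_list, levels, idxs):
        # Mask for the first (level, idx) pair, stacked over the batch.
        mask = torch.tensor(np.stack([get_seq_mask(i, levels[0], idxs[0]) for i in idx_list], axis=0))
        # AND in one mask per remaining pair, mirroring the loop in the diff.
        for n, level in enumerate(levels[1:]):
            mask = mask & torch.tensor(np.stack([get_seq_mask(i, level, idxs[n + 1]) for i in idx_list], axis=0))
        return mask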

View File

@@ -1,70 +0,0 @@
-import torch
-import sys
-
-from modelscope import snapshot_download
-from wit.model.modeling_wit import QWenLMHeadModel
-from wit.model.modeling_wit import QwenRunner
-from wit.configuration import ModelConfig
-from wit.model.tokenization_qwen import QWenTokenizer
-from wit.model.qwen_generation_utils import (
-    make_context,
-    decode_tokens,
-)
-
-seed = 4321
-torch.manual_seed(seed)
-torch.cuda.manual_seed_all(seed)
-
-model_dir = snapshot_download("qwen/Qwen-1_8B-Chat")
-# model_dir = "/home/colin/.cache/modelscope/hub/qwen/Qwen-1_8B-Chat"
-
-config = ModelConfig()
-model = QWenLMHeadModel(config)
-print(model)
-
-tokenizer = QWenTokenizer("./model/wit_b64.tiktoken", "./model/wit_char.tiktoken")
-
-sys.path.append("..")
-from tools import show
-
-
-def Dump_tokens_list(model):
-    tokens = []
-    for token in range(4096):
-        decoded, response, end_reason = decode_tokens(
-            [token],
-            tokenizer,
-            raw_text_len=0,
-            context_length=0,
-            errors="replace",
-        )
-        tokens.append(str(token).zfill(7) + ": " + repr(decoded))
-    show.DumpListToFile(tokens, "./temp/qwen_token_list.txt")
-
-
-Dump_tokens_list(model)
-
-model = model.from_pretrained(model_dir).cuda()
-
-# state = model.state_dict()
-# torch.save(state, "model_params.pth")
-# model.load_state_dict(torch.load('model_params.pth'))
-
-model = model.eval()
-# model = model.train()  # control by @torch.no_grad()
-
-runner = QwenRunner(model)
-
-output_ids, history, decode_tokens = runner.Chat(tokenizer, "你好", "", 20)
-print(decode_tokens)
-
-for i, token in enumerate(output_ids):
-    de = tokenizer.decode([token])
-    de = str(i + 1).zfill(3) + " : " + repr(de)
-    print(de)

View File

@@ -17,7 +17,7 @@ The meaning dataset is a dataset that imitates natural language and abstract expression.
 11. get_seq_mask returns, for each token of a sequence, whether that token is at the given index of the given level; level=0 is the bottom level, index=-1 means the last one, index=0 the first one
 12. meaning_height: the total height of the current meaning
 13. meaning_weight: the total width of the current meaning
-14. mask_level mask_idx: the mask over tokens used for training; mask_level=[0, 1, 2] mask_idx=[0, 0, -1] means that only a token which is the 0th at level 0, the 0th at level 1, and the last at level 2 takes part in training
+14. val_mask_level val_mask_idx: the mask over tokens used for training; val_mask_level=[0, 1, 2] val_mask_idx=[0, 0, -1] means that only a token which is the 0th at level 0, the 0th at level 1, and the last at level 2 takes part in training
 ```
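
A toy illustration of item 14, with made-up per-level boolean masks for a five-token sequence (in the real code these come from get_seq_mask):

    import numpy as np

    # Hypothetical get_seq_mask outputs for val_mask_level=[0, 1, 2], val_mask_idx=[0, 0, -1]
    first_at_level0 = np.array([True, True, False, True, True])
    first_at_level1 = np.array([True, False, True, True, True])
    last_at_level2 = np.array([True, True, True, False, True])

    # A token participates in training only if every per-level condition holds.
    train_mask = first_at_level0 & first_at_level1 & last_at_level2
    print(train_mask)  # [ True False False False  True]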

View File

@@ -195,6 +195,18 @@ class QwenRunner:
         self.qwen = qwen
         # torch.backends.cuda.enable_flash_sdp(True)

+    @torch.no_grad()
+    def ChatToken(self, input_ids):
+        qwen = self.qwen
+        input_ids = input_ids.to(next(qwen.parameters()).device)
+        outputs, loss = self.forwardQWen(input_ids)
+        next_token_scores = outputs[:, -1, :]
+        next_token_scores = self.repetition_penalty(input_ids, next_token_scores)
+        next_token_scores = self.top_p(next_token_scores)
+        next_tokens = self.sample(next_token_scores)
+        return next_tokens
+
     @torch.no_grad()
     def Chat(
         self,
@@ -214,7 +226,7 @@ class QwenRunner:
         self.unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
         input_length = input_ids.shape[1]
         while True:
-            outputs = self.forwardQWen(input_ids)
+            outputs, loss = self.forwardQWen(input_ids)
             next_token_scores = outputs[:, -1, :]
             next_token_scores = self.repetition_penalty(input_ids, next_token_scores)
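
The new ChatToken runs a single decode step through the same sampling pipeline Chat uses (repetition penalty, top-p, sampling). A hedged usage sketch; the runner and tokenizer come from this repo, while the encode() call is an assumed helper for turning a prompt into ids:

    runner = QwenRunner(model)
    input_ids = torch.tensor([tokenizer.encode("你好")])  # assumed encoder; shape [batch, seq]
    next_tokens = runner.ChatToken(input_ids)  # one sampled token id per batch row
    print(tokenizer.decode(next_tokens.tolist()))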

View File

@@ -9,7 +9,7 @@ from model.modeling_wit import QWenLMHeadModel
 from configuration import ModelConfig, TrainConfig

-class LitModule(pl.LightningModule):
+class QwenModule(pl.LightningModule):
     def __init__(self, conf: TrainConfig = None):
         pretrained_model_dir = conf.pretrain_model_name
         learning_rate = conf.learning_rate