diff --git a/wit/configuration.py b/wit/configuration.py
index d0f63b3..be0c71e 100644
--- a/wit/configuration.py
+++ b/wit/configuration.py
@@ -42,8 +42,8 @@ class MeaningDatasetConfig:
         self.level = 5
         self.dataset_level = 3
         self.min_subitem = 2
-        self.mask_level = None
-        self.mask_idx = None
+        self.val_mask_level = None
+        self.val_mask_idx = None
 
 
 class DatasetConfig:
diff --git a/wit/dataset/dataset.py b/wit/dataset/dataset.py
index 842567a..3ed5496 100644
--- a/wit/dataset/dataset.py
+++ b/wit/dataset/dataset.py
@@ -42,14 +42,14 @@ def InitDataset(config):
     if os.path.exists(trainfile) and os.path.exists(valfile):
         print(f"INFO: Load dataset from {trainfile}")
         train_dataset = torch.load(trainfile, weights_only=False)
-        train_dataset.set_mask(c.mask_level, c.mask_idx)
+        train_dataset.set_mask(c.val_mask_level, c.val_mask_idx)
         print(f"INFO: Load dataset from {valfile}")
         val_dataset = torch.load(valfile, weights_only=False)
-        val_dataset.set_mask(c.mask_level, c.mask_idx)
+        val_dataset.set_mask(c.val_mask_level, c.val_mask_idx)
         print(f"INFO: Load dataset end")
     else:
         raw_dataset = MeaningDataset(start, start + size, vocab, None, c.level_ratio, c.min_subitem)
-        raw_dataset.set_mask(c.mask_level, c.mask_idx)
+        raw_dataset.set_mask(c.val_mask_level, c.val_mask_idx)
         train_dataset, val_dataset = raw_dataset.split(0.9)
         torch.save(train_dataset, trainfile)
         torch.save(val_dataset, valfile)
diff --git a/wit/dataset/meaning_dataset.py b/wit/dataset/meaning_dataset.py
index f94cc10..2289bdb 100644
--- a/wit/dataset/meaning_dataset.py
+++ b/wit/dataset/meaning_dataset.py
@@ -268,8 +268,8 @@ class MeaningDataset(Dataset):
         map = MeaningMap(end, vocab_size, max_subitem, min_subitem, use_cache=use_cache)
         np.random.seed(seed)
         print("Build MeaningDataset from MeaningMap.")
-        self.mask_level = None
-        self.mask_idx = None
+        self.val_mask_level = None
+        self.val_mask_idx = None
         self.tree = []
         self.seq = []
         self.level = []
@@ -334,13 +334,13 @@ class MeaningDataset(Dataset):
         return len(self.seq)
 
     def set_mask(self, level=None, idx=None):
-        if self.mask_level is not None and self.mask_idx is not None:
-            assert len(self.mask_level) > 0, "len must > 0"
-            assert len(self.mask_level) == len(self.mask_idx), "mask level and mask index must be same length"
-            assert isinstance(self.mask_level, list), "mask level must be list"
-            assert isinstance(self.mask_idx, list), "mask index must be list"
-        self.mask_level = level
-        self.mask_idx = idx
+        if self.val_mask_level is not None and self.val_mask_idx is not None:
+            assert len(self.val_mask_level) > 0, "len must > 0"
+            assert len(self.val_mask_level) == len(self.val_mask_idx), "mask level and mask index must be same length"
+            assert isinstance(self.val_mask_level, list), "mask level must be list"
+            assert isinstance(self.val_mask_idx, list), "mask index must be list"
+        self.val_mask_level = level
+        self.val_mask_idx = idx
 
     def __getitem__(self, idx):
         return self.get_batch([idx])
@@ -377,8 +377,8 @@ class MeaningDataset(Dataset):
         new.rank_idx = new.rank_idx[start:end]
         new.rank_all = new.rank_all[start:end]
         new.seq_meaning = new.seq_meaning[start:end]
-        new.mask_level = self.mask_level
-        new.mask_idx = self.mask_idx
+        new.val_mask_level = self.val_mask_level
+        new.val_mask_idx = self.val_mask_idx
         return new
 
     def split(self, ratio):
@@ -400,13 +400,15 @@ class MeaningDataset(Dataset):
         return rank_idx == (rank_all + index if index < 0 else index)
 
     def get_seq_mask_tensor(self, idx_list):
-        if self.mask_level is not None and self.mask_idx is not None:
+        if self.val_mask_level is not None and self.val_mask_idx is not None:
             mask = torch.tensor(
-                np.stack([self.get_seq_mask(idx, self.mask_level[0], self.mask_idx[0]) for idx in idx_list], axis=0)
+                np.stack(
+                    [self.get_seq_mask(idx, self.val_mask_level[0], self.val_mask_idx[0]) for idx in idx_list], axis=0
+                )
             )
-            for i, l in enumerate(self.mask_level[1:]):
+            for i, l in enumerate(self.val_mask_level[1:]):
                 mask = mask & torch.tensor(
-                    np.stack([self.get_seq_mask(idx, l, self.mask_idx[i + 1]) for idx in idx_list], axis=0)
+                    np.stack([self.get_seq_mask(idx, l, self.val_mask_idx[i + 1]) for idx in idx_list], axis=0)
                 )
             return mask
         else:
diff --git a/wit/demo.py b/wit/demo.py
deleted file mode 100644
index a5cd2b9..0000000
--- a/wit/demo.py
+++ /dev/null
@@ -1,70 +0,0 @@
-import torch
-import sys
-from modelscope import snapshot_download
-
-from wit.model.modeling_wit import QWenLMHeadModel
-from wit.model.modeling_wit import QwenRunner
-from wit.configuration import ModelConfig
-from wit.model.tokenization_qwen import QWenTokenizer
-
-
-from wit.model.qwen_generation_utils import (
-    make_context,
-    decode_tokens,
-)
-
-seed = 4321
-torch.manual_seed(seed)
-torch.cuda.manual_seed_all(seed)
-
-model_dir = snapshot_download("qwen/Qwen-1_8B-Chat")
-# model_dir = "/home/colin/.cache/modelscope/hub/qwen/Qwen-1_8B-Chat"
-
-config = ModelConfig()
-model = QWenLMHeadModel(config)
-
-print(model)
-
-tokenizer = QWenTokenizer("./model/wit_b64.tiktoken", "./model/wit_char.tiktoken")
-
-sys.path.append("..")
-from tools import show
-
-
-def Dump_tokens_list(model):
-    tokens = []
-    for token in range(4096):
-        decoded, response, end_reason = decode_tokens(
-            [token],
-            tokenizer,
-            raw_text_len=0,
-            context_length=0,
-            errors="replace",
-        )
-        tokens.append(str(token).zfill(7) + ": " + repr(decoded))
-    show.DumpListToFile(tokens, "./temp/qwen_token_list.txt")
-
-
-Dump_tokens_list(model)
-
-
-model = model.from_pretrained(model_dir).cuda()
-
-# state = model.state_dict()
-# torch.save(state, "model_params.pth")
-# model.load_state_dict(torch.load('model_params.pth'))
-
-
-model = model.eval()
-# model = model.train() # control by @torch.no_grad()
-
-
-runner = QwenRunner(model)
-
-output_ids, history, decode_tokens = runner.Chat(tokenizer, "你好", "", 20)
-print(decode_tokens)
-
-for i, token in enumerate(output_ids):
-    de = tokenizer.decode([token])
-    de = str(i + 1).zfill(3) + " : " + repr(de)
-    print(de)
diff --git a/wit/doc/meaning_dataset.md b/wit/doc/meaning_dataset.md
index 62e5bd5..772147c 100644
--- a/wit/doc/meaning_dataset.md
+++ b/wit/doc/meaning_dataset.md
@@ -17,7 +17,7 @@ meaning数据集是一个模仿自然语言,以及抽象表达的数据集。
 11. get_seq_mask返回一个sequence每个token在对应level是不是对应的index,level=0:最底层,index=-1:最后一个,index=0:第一个
 12. meaning_height 当前meaning的总高度
 13. meaning_weight 当前meaning的总宽度
-14. mask_level mask_idx: 表示用于训练的token的mask,mask_level=[0, 1, 2] mask_idx=[0, 0, -1]表示只有是第0层第0个,而且是第1层第0个,第2层最后一个的token,才参与训练
+14. val_mask_level val_mask_idx: 表示用于训练的token的mask,val_mask_level=[0, 1, 2] val_mask_idx=[0, 0, -1]表示只有是第0层第0个,而且是第1层第0个,第2层最后一个的token,才参与训练
 
 ```
 
diff --git a/wit/model/modeling_wit.py b/wit/model/modeling_wit.py
index 94ec99f..fac0b7c 100644
--- a/wit/model/modeling_wit.py
+++ b/wit/model/modeling_wit.py
@@ -195,6 +195,18 @@ class QwenRunner:
         self.qwen = qwen
         # torch.backends.cuda.enable_flash_sdp(True)
 
+    @torch.no_grad()
+    def ChatToken(self, input_ids):
+        qwen = self.qwen
+        input_ids = input_ids.to(next(qwen.parameters()).device)
+        outputs, loss = self.forwardQWen(input_ids)
+        next_token_scores = outputs[:, -1, :]
+
+        next_token_scores = self.repetition_penalty(input_ids, next_token_scores)
+        next_token_scores = self.top_p(next_token_scores)
+        next_tokens = self.sample(next_token_scores)
+        return next_tokens
+
     @torch.no_grad()
     def Chat(
         self,
@@ -214,7 +226,7 @@ class QwenRunner:
         self.unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
         input_length = input_ids.shape[1]
         while True:
-            outputs = self.forwardQWen(input_ids)
+            outputs, loss = self.forwardQWen(input_ids)
             next_token_scores = outputs[:, -1, :]
 
             next_token_scores = self.repetition_penalty(input_ids, next_token_scores)
diff --git a/wit/model/lit_module.py b/wit/model/qwen_module.py
similarity index 99%
rename from wit/model/lit_module.py
rename to wit/model/qwen_module.py
index e2303da..9a71dae 100644
--- a/wit/model/lit_module.py
+++ b/wit/model/qwen_module.py
@@ -9,7 +9,7 @@ from model.modeling_wit import QWenLMHeadModel
 from configuration import ModelConfig, TrainConfig
 
 
-class LitModule(pl.LightningModule):
+class QwenModule(pl.LightningModule):
     def __init__(self, conf: TrainConfig = None):
         pretrained_model_dir = conf.pretrain_model_name
         learning_rate = conf.learning_rate
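For reference, a minimal usage sketch of the renamed validation-mask fields. It follows the semantics documented in wit/doc/meaning_dataset.md item 14 and the load/set_mask pattern in InitDataset; it assumes MeaningDatasetConfig can be constructed without arguments, and the dataset file path and batch indices are placeholders rather than values from the repository.

```python
import torch

from wit.configuration import MeaningDatasetConfig

# Per doc item 14: only tokens that are the first child at level 0, the first
# child at level 1, and the last child at level 2 contribute to training.
c = MeaningDatasetConfig()
c.val_mask_level = [0, 1, 2]  # level 0 is the lowest level of the meaning tree
c.val_mask_idx = [0, 0, -1]   # required index per level; -1 means "last child"

# Same loading pattern as InitDataset; "meaning_train.pt" is a placeholder path.
train_dataset = torch.load("meaning_train.pt", weights_only=False)
train_dataset.set_mask(c.val_mask_level, c.val_mask_idx)

# get_seq_mask_tensor ANDs the per-level masks, so a token is kept only when it
# satisfies every (level, idx) pair set above.
mask = train_dataset.get_seq_mask_tensor([0, 1, 2])
```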