Rename mask level and index.
commit 7cf31a1f78 (parent bca06af2dc)
@@ -42,8 +42,8 @@ class MeaningDatasetConfig:
         self.level = 5
         self.dataset_level = 3
         self.min_subitem = 2
-        self.mask_level = None
-        self.mask_idx = None
+        self.val_mask_level = None
+        self.val_mask_idx = None


 class DatasetConfig:
@@ -42,14 +42,14 @@ def InitDataset(config):
     if os.path.exists(trainfile) and os.path.exists(valfile):
         print(f"INFO: Load dataset from {trainfile}")
         train_dataset = torch.load(trainfile, weights_only=False)
-        train_dataset.set_mask(c.mask_level, c.mask_idx)
+        train_dataset.set_mask(c.val_mask_level, c.val_mask_idx)
         print(f"INFO: Load dataset from {valfile}")
         val_dataset = torch.load(valfile, weights_only=False)
-        val_dataset.set_mask(c.mask_level, c.mask_idx)
+        val_dataset.set_mask(c.val_mask_level, c.val_mask_idx)
         print(f"INFO: Load dataset end")
     else:
         raw_dataset = MeaningDataset(start, start + size, vocab, None, c.level_ratio, c.min_subitem)
-        raw_dataset.set_mask(c.mask_level, c.mask_idx)
+        raw_dataset.set_mask(c.val_mask_level, c.val_mask_idx)
         train_dataset, val_dataset = raw_dataset.split(0.9)
         torch.save(train_dataset, trainfile)
         torch.save(val_dataset, valfile)
@@ -268,8 +268,8 @@ class MeaningDataset(Dataset):
         map = MeaningMap(end, vocab_size, max_subitem, min_subitem, use_cache=use_cache)
         np.random.seed(seed)
         print("Build MeaningDataset from MeaningMap.")
-        self.mask_level = None
-        self.mask_idx = None
+        self.val_mask_level = None
+        self.val_mask_idx = None
         self.tree = []
         self.seq = []
         self.level = []
@@ -334,13 +334,13 @@ class MeaningDataset(Dataset):
         return len(self.seq)

     def set_mask(self, level=None, idx=None):
-        if self.mask_level is not None and self.mask_idx is not None:
-            assert len(self.mask_level) > 0, "len must > 0"
-            assert len(self.mask_level) == len(self.mask_idx), "mask level and mask index must be same length"
-            assert isinstance(self.mask_level, list), "mask level must be list"
-            assert isinstance(self.mask_idx, list), "mask index must be list"
-        self.mask_level = level
-        self.mask_idx = idx
+        if self.val_mask_level is not None and self.val_mask_idx is not None:
+            assert len(self.val_mask_level) > 0, "len must > 0"
+            assert len(self.val_mask_level) == len(self.val_mask_idx), "mask level and mask index must be same length"
+            assert isinstance(self.val_mask_level, list), "mask level must be list"
+            assert isinstance(self.val_mask_idx, list), "mask index must be list"
+        self.val_mask_level = level
+        self.val_mask_idx = idx

     def __getitem__(self, idx):
         return self.get_batch([idx])
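Note on the setter's contract: set_mask expects the mask as two lists of equal length (one entry per level, see the assertions above), and calling it with the default None arguments clears the mask. A minimal usage sketch, assuming train_dataset is a MeaningDataset built as in InitDataset above (the mask values are hypothetical, for illustration only):

```
# Keep only tokens that are the first item of their level-0 group AND the
# last item of their level-1 group (hypothetical values).
train_dataset.set_mask([0, 1], [0, -1])

# Calling the setter with no arguments (level=None, idx=None) clears the mask,
# so get_seq_mask_tensor falls back to its else branch.
train_dataset.set_mask()
```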
@@ -377,8 +377,8 @@ class MeaningDataset(Dataset):
         new.rank_idx = new.rank_idx[start:end]
         new.rank_all = new.rank_all[start:end]
         new.seq_meaning = new.seq_meaning[start:end]
-        new.mask_level = self.mask_level
-        new.mask_idx = self.mask_idx
+        new.val_mask_level = self.val_mask_level
+        new.val_mask_idx = self.val_mask_idx
         return new

     def split(self, ratio):
@@ -400,13 +400,15 @@ class MeaningDataset(Dataset):
         return rank_idx == (rank_all + index if index < 0 else index)

     def get_seq_mask_tensor(self, idx_list):
-        if self.mask_level is not None and self.mask_idx is not None:
+        if self.val_mask_level is not None and self.val_mask_idx is not None:
             mask = torch.tensor(
-                np.stack([self.get_seq_mask(idx, self.mask_level[0], self.mask_idx[0]) for idx in idx_list], axis=0)
+                np.stack(
+                    [self.get_seq_mask(idx, self.val_mask_level[0], self.val_mask_idx[0]) for idx in idx_list], axis=0
+                )
             )
-            for i, l in enumerate(self.mask_level[1:]):
+            for i, l in enumerate(self.val_mask_level[1:]):
                 mask = mask & torch.tensor(
-                    np.stack([self.get_seq_mask(idx, l, self.mask_idx[i + 1]) for idx in idx_list], axis=0)
+                    np.stack([self.get_seq_mask(idx, l, self.val_mask_idx[i + 1]) for idx in idx_list], axis=0)
                 )
             return mask
         else:
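get_seq_mask_tensor builds one boolean mask per (level, index) pair and ANDs them together, so a token survives only if it matches every entry of the mask. A standalone sketch of that combination step, with made-up per-level masks standing in for get_seq_mask:

```
import numpy as np
import torch

# Hypothetical per-token results of get_seq_mask for a single 6-token sequence:
# level 0, idx 0  -> token is the first item of its lowest-level group
# level 1, idx -1 -> token is the last item of its level-1 group
level0_first = np.array([True, False, True, False, True, False])
level1_last = np.array([False, False, True, False, False, True])

# Same shape of computation as the diff: stack the first level's mask for the
# batch, then AND in each remaining level.
mask = torch.tensor(np.stack([level0_first], axis=0))
mask = mask & torch.tensor(np.stack([level1_last], axis=0))
print(mask)  # tensor([[False, False,  True, False, False, False]])
```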
wit/demo.py (70 lines removed)
@@ -1,70 +0,0 @@
-import torch
-import sys
-from modelscope import snapshot_download
-
-from wit.model.modeling_wit import QWenLMHeadModel
-from wit.model.modeling_wit import QwenRunner
-from wit.configuration import ModelConfig
-from wit.model.tokenization_qwen import QWenTokenizer
-
-
-from wit.model.qwen_generation_utils import (
-    make_context,
-    decode_tokens,
-)
-
-seed = 4321
-torch.manual_seed(seed)
-torch.cuda.manual_seed_all(seed)
-
-model_dir = snapshot_download("qwen/Qwen-1_8B-Chat")
-# model_dir = "/home/colin/.cache/modelscope/hub/qwen/Qwen-1_8B-Chat"
-
-config = ModelConfig()
-model = QWenLMHeadModel(config)
-
-print(model)
-
-tokenizer = QWenTokenizer("./model/wit_b64.tiktoken", "./model/wit_char.tiktoken")
-
-sys.path.append("..")
-from tools import show
-
-
-def Dump_tokens_list(model):
-    tokens = []
-    for token in range(4096):
-        decoded, response, end_reason = decode_tokens(
-            [token],
-            tokenizer,
-            raw_text_len=0,
-            context_length=0,
-            errors="replace",
-        )
-        tokens.append(str(token).zfill(7) + ": " + repr(decoded))
-    show.DumpListToFile(tokens, "./temp/qwen_token_list.txt")
-
-
-Dump_tokens_list(model)
-
-
-model = model.from_pretrained(model_dir).cuda()
-
-# state = model.state_dict()
-# torch.save(state, "model_params.pth")
-# model.load_state_dict(torch.load('model_params.pth'))
-
-
-model = model.eval()
-# model = model.train()  # control by @torch.no_grad()
-
-
-runner = QwenRunner(model)
-
-output_ids, history, decode_tokens = runner.Chat(tokenizer, "你好", "", 20)
-print(decode_tokens)
-
-for i, token in enumerate(output_ids):
-    de = tokenizer.decode([token])
-    de = str(i + 1).zfill(3) + " : " + repr(de)
-    print(de)
@@ -17,7 +17,7 @@ The meaning dataset imitates natural language as well as abstract expression.
 11. get_seq_mask returns, for each token of a sequence, whether that token is at the given index of the given level; level=0: the lowest level, index=-1: the last item, index=0: the first item
 12. meaning_height: the total height of the current meaning
 13. meaning_weight: the total width of the current meaning
-14. mask_level mask_idx: the mask selecting which tokens are used for training; mask_level=[0, 1, 2] mask_idx=[0, 0, -1] means a token takes part in training only if it is the 0th item at level 0, the 0th item at level 1, and the last item at level 2
+14. val_mask_level val_mask_idx: the mask selecting which tokens are used for training; val_mask_level=[0, 1, 2] val_mask_idx=[0, 0, -1] means a token takes part in training only if it is the 0th item at level 0, the 0th item at level 1, and the last item at level 2


 ```
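To make item 14 concrete: with val_mask_level=[0, 1, 2] and val_mask_idx=[0, 0, -1], a token contributes to training only if it is simultaneously the 0th item of its level-0 group, the 0th item of its level-1 group, and the last item of its level-2 group. A hedged sketch of wiring that through a config (the conf and dataset names are placeholders, not taken from this file):

```
conf.val_mask_level = [0, 1, 2]   # levels to constrain, 0 = lowest level
conf.val_mask_idx = [0, 0, -1]    # required index at each level (-1 = last item)
dataset.set_mask(conf.val_mask_level, conf.val_mask_idx)
```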
@@ -195,6 +195,18 @@ class QwenRunner:
         self.qwen = qwen
         # torch.backends.cuda.enable_flash_sdp(True)

+    @torch.no_grad()
+    def ChatToken(self, input_ids):
+        qwen = self.qwen
+        input_ids = input_ids.to(next(qwen.parameters()).device)
+        outputs, loss = self.forwardQWen(input_ids)
+        next_token_scores = outputs[:, -1, :]
+
+        next_token_scores = self.repetition_penalty(input_ids, next_token_scores)
+        next_token_scores = self.top_p(next_token_scores)
+        next_tokens = self.sample(next_token_scores)
+        return next_tokens
+
     @torch.no_grad()
     def Chat(
         self,
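The new ChatToken helper runs a single forward pass and samples one next token, reusing the same repetition_penalty / top_p / sample chain as Chat. A usage sketch under stated assumptions (the prompt token ids are placeholder values; tokenizer.decode is used as in the deleted wit/demo.py, everything else follows the diff):

```
import torch

# Assumes `model` is a loaded QWenLMHeadModel and `tokenizer` a QWenTokenizer,
# constructed the same way as elsewhere in this repository.
runner = QwenRunner(model)

input_ids = torch.tensor([[872, 108386]], dtype=torch.long)  # placeholder token ids
next_token = runner.ChatToken(input_ids)       # one sampled token id per batch row
print(tokenizer.decode(next_token.tolist()))   # decode() as used in the old demo.py
```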
@@ -214,7 +226,7 @@ class QwenRunner:
         self.unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
         input_length = input_ids.shape[1]
         while True:
-            outputs = self.forwardQWen(input_ids)
+            outputs, loss = self.forwardQWen(input_ids)
             next_token_scores = outputs[:, -1, :]

             next_token_scores = self.repetition_penalty(input_ids, next_token_scores)
@@ -9,7 +9,7 @@ from model.modeling_wit import QWenLMHeadModel
 from configuration import ModelConfig, TrainConfig


-class LitModule(pl.LightningModule):
+class QwenModule(pl.LightningModule):
     def __init__(self, conf: TrainConfig = None):
         pretrained_model_dir = conf.pretrain_model_name
         learning_rate = conf.learning_rate