commit fe13f12327 (parent 6366b52fef)
Author: Colin
Date: 2024-02-06 14:08:45 +08:00

6 changed files with 3838 additions and 49 deletions

configuration_qwen.py

@@ -6,10 +6,10 @@
 class QWenConfig:
     def __init__(self):
-        self.vocab_size = 151936
-        self.hidden_size = 2048
-        self.num_hidden_layers = 24
-        self.num_attention_heads = 16
+        self.vocab_size = 4096
+        self.hidden_size = 1024  # 1024 2048
+        self.num_hidden_layers = 12  # 12 24
+        self.num_attention_heads = 8  # 8 16
         self.emb_dropout_prob = 0.0
         self.attn_dropout_prob = 0.0
         self.layer_norm_epsilon = 1e-6
@@ -26,7 +26,7 @@ class QWenConfig:
         self.use_dynamic_ntk = True
         self.use_logn_attn = True
         self.use_flash_attn = "auto"
-        self.intermediate_size = 11008
+        self.intermediate_size = 5504  # 5504 11008
         self.no_bias = True
         self.tie_word_embeddings = False
         self.use_cache_quantization = False
@@ -34,8 +34,6 @@ class QWenConfig:
         self.softmax_in_fp32 = False
         self.chat_format = "chatml"
-        self.eos_token_id = 151643
-        self.pad_token_id = 151643
         self.max_window_size = 6144
         self.max_new_tokens = 512
         self.do_sample = True
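Note: these values scale the network down from the released Qwen sizes to a small experimental model. A rough back-of-the-envelope sketch of what the reduced config implies, assuming the usual Qwen layout of QKV/output projections plus a gated MLP and ignoring biases, norms, and rotary tables:

# Approximate parameter count for the reduced config (estimate only).
hidden, layers, inter, vocab = 1024, 12, 5504, 4096

embed = vocab * hidden                  # token embedding
attn_per_layer = 4 * hidden * hidden    # QKV projections + output projection, roughly
mlp_per_layer = 3 * hidden * inter      # gated MLP: w1, w2, c_proj
total = embed + layers * (attn_per_layer + mlp_per_layer) + vocab * hidden  # + lm_head

print(f"~{total / 1e6:.0f}M parameters")  # on the order of 260M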

View File

@@ -1,4 +1,5 @@
 import torch
+import sys
 from modelscope import snapshot_download
 from modeling_wit import QWenLMHeadModel
@@ -6,6 +7,12 @@ from modeling_wit import QwenRunner
 from configuration_qwen import QWenConfig
 from tokenization_qwen import QWenTokenizer
+from qwen_generation_utils import (
+    make_context,
+    decode_tokens,
+)
 seed = 4321
 torch.manual_seed(seed)
 torch.cuda.manual_seed_all(seed)
@@ -18,14 +25,46 @@ model = QWenLMHeadModel(config)
 print(model)
-tokenizer = QWenTokenizer("./qwen.tiktoken")
+tokenizer = QWenTokenizer("./wit_b64.tiktoken", "./wit_char.tiktoken")
+sys.path.append("..")
+from tools import show
+
+
+def Dump_tokens_list(model):
+    tokens = []
+    for token in range(4096):
+        decoded, response, end_reason = decode_tokens(
+            [token],
+            tokenizer,
+            raw_text_len=0,
+            context_length=0,
+            errors="replace",
+        )
+        tokens.append(str(token).zfill(7) + ": " + repr(decoded))
+    show.DumpListToFile(tokens, "./temp/qwen_token_list.txt")
+
+
+Dump_tokens_list(model)
 model = model.from_pretrained(model_dir).cuda()
+# state = model.state_dict()
+# torch.save(state, "model_params.pth")
+# model.load_state_dict(torch.load('model_params.pth'))
 model = model.eval()
 # model = model.train()  # control by @torch.no_grad()
 runner = QwenRunner(model)
-response, history, decode_tokens = runner.Chat(tokenizer, "东南亚国家日本的首都是什么市", "")
+output_ids, history, decode_tokens = runner.Chat(tokenizer, "你好", "", 20)
 print(decode_tokens)
+
+for i, token in enumerate(output_ids):
+    de = tokenizer.decode([token])
+    de = str(i + 1).zfill(3) + " : " + repr(de)
+    print(de)
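The tools.show module reached via sys.path.append("..") is not part of this diff. Assuming DumpListToFile simply writes one list entry per line, a minimal stand-in could look like this (hypothetical helper, not the project's actual implementation):

import os

def DumpListToFile(lines, path):
    # Hypothetical stand-in for tools.show.DumpListToFile: one entry per line, UTF-8.
    os.makedirs(os.path.dirname(path), exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        f.write("\n".join(lines) + "\n")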

modeling_wit.py

@@ -193,14 +193,18 @@ class QwenRunner:
         tokenizer,
         query: str,
         query_assistant: str,
+        gen_length=0,
         system: str = "You are a helpful assistant.",
         history=[],
     ):
         qwen = self.qwen
         history = copy.deepcopy(history)
+        self.qwen.config.pad_token_id = tokenizer.eod_id
+        self.qwen.config.eos_token_id = tokenizer.eod_id
         raw_text, context_tokens = self.prepareInput(tokenizer, query, query_assistant, history, system)
         input_ids = torch.tensor([context_tokens]).to(next(qwen.parameters()).device)
         self.unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
+        input_length = input_ids.shape[1]
         while True:
             outputs = self.forwardQWen(input_ids)
             next_token_scores = outputs[:, -1, :]
@@ -212,6 +216,8 @@ class QwenRunner:
             if finish:
                 break
             input_ids = torch.cat([input_ids, next_tokens], dim=-1)
+            if gen_length != 0 and (input_length + gen_length) < input_ids.shape[1]:
+                break
         decoded, response, end_reason = decode_tokens(
             input_ids[0],
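The new gen_length argument caps how much is generated beyond the prompt: once input_ids has grown past input_length + gen_length tokens, the loop exits even without an end-of-text token, while gen_length=0 keeps the old unbounded behaviour. The same stopping rule in isolation, sketched with a dummy token source (illustration only, not the runner code):

def capped_generate(prompt_ids, next_token_fn, gen_length=0, eos_id=None):
    # Sketch of the Chat() stopping rule: stop on EOS, or once more than
    # gen_length tokens have been appended to the prompt (0 = no cap).
    ids = list(prompt_ids)
    input_length = len(ids)
    while True:
        tok = next_token_fn(ids)
        if eos_id is not None and tok == eos_id:
            break
        ids.append(tok)
        if gen_length != 0 and (input_length + gen_length) < len(ids):
            break
    return ids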

tokenization_qwen.py

@@ -26,31 +26,21 @@ IMEND = "<|im_end|>"
 # as the default behavior is changed to allow special tokens in
 # regular texts, the surface forms of special tokens need to be
 # as different as possible to minimize the impact
-EXTRAS = tuple((f"<|extra_{i}|>" for i in range(205)))
 # changed to use actual index to avoid misconfiguration with vocabulary expansion
-SPECIAL_START_ID = 151643
-SPECIAL_TOKENS = tuple(
-    enumerate(
-        (
-            (
-                ENDOFTEXT,
-                IMSTART,
-                IMEND,
-            )
-            + EXTRAS
-        ),
-        start=SPECIAL_START_ID,
-    )
-)
-SPECIAL_TOKENS_SET = set(t for i, t in SPECIAL_TOKENS)
-def _load_tiktoken_bpe(tiktoken_bpe_file: str) -> Dict[bytes, int]:
+def _load_tiktoken_b64(tiktoken_bpe_file: str, startRank=0) -> Dict[bytes, int]:
     with open(tiktoken_bpe_file, "rb") as f:
         contents = f.read()
-    return {
-        base64.b64decode(token): int(rank) for token, rank in (line.split() for line in contents.splitlines() if line)
-    }
+    ll = contents.splitlines()
+    return {base64.b64decode(token): int(rank + startRank) for rank, token in enumerate(ll)}
+
+
+def _load_tiktoken_char(tiktoken_bpe_file: str, startRank=0) -> Dict[bytes, int]:
+    with open(tiktoken_bpe_file, "rb") as f:
+        contents = f.read()
+    ll = contents.splitlines()
+    return {token: int(rank + startRank) for rank, token in enumerate(ll)}
 class QWenTokenizer(PreTrainedTokenizer):
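The two loaders split the vocabulary across the two new files: wit_b64.tiktoken carries base64-encoded byte tokens, wit_char.tiktoken carries raw character tokens, and startRank chains the second file's ranks after the first. A small usage sketch of how they compose (the byte values shown come from the first lines of wit_b64.tiktoken below; the layout of the char file is assumed, one token per line):

b64_ranks = _load_tiktoken_b64("./wit_b64.tiktoken")          # e.g. {b"!": 0, b'"': 1, ...}
char_ranks = _load_tiktoken_char("./wit_char.tiktoken", startRank=len(b64_ranks))
mergeable_ranks = {**b64_ranks, **char_ranks}                  # char ranks continue after the byte ranks
assert max(b64_ranks.values()) < min(char_ranks.values())      # no rank overlap between the two files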
@@ -60,9 +50,9 @@ class QWenTokenizer(PreTrainedTokenizer):
     def __init__(
         self,
-        vocab_file,
+        vocab_file_b64,
+        vocab_file_char,
         errors="replace",
-        extra_vocab_file=None,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -71,22 +61,24 @@ class QWenTokenizer(PreTrainedTokenizer):
         # use ignore if you are in streaming inference
         self.errors = errors
-        self.mergeable_ranks = _load_tiktoken_bpe(vocab_file)  # type: Dict[bytes, int]
-        self.special_tokens = {token: index for index, token in SPECIAL_TOKENS}
-        # try load extra vocab from file
-        if extra_vocab_file is not None:
-            used_ids = set(self.mergeable_ranks.values()) | set(self.special_tokens.values())
-            extra_mergeable_ranks = _load_tiktoken_bpe(extra_vocab_file)
-            for token, index in extra_mergeable_ranks.items():
-                if token in self.mergeable_ranks:
-                    logger.info(f"extra token {token} exists, skipping")
-                    continue
-                if index in used_ids:
-                    logger.info(f"the index {index} for extra token {token} exists, skipping")
-                    continue
-                self.mergeable_ranks[token] = index
-            # the index may be sparse after this, but don't worry tiktoken.Encoding will handle this
+        self.mergeable_ranks = _load_tiktoken_b64(vocab_file_b64)
+        self.mergeable_ranks.update(_load_tiktoken_char(vocab_file_char, len(self.mergeable_ranks)))
+        special = (
+            "user",
+            "assistant",
+            ENDOFTEXT,
+            IMSTART,
+            IMEND,
+        )
+        extras = tuple((f"<|extra_{i}|>" for i in range(4096 - len(self.mergeable_ranks) - len(special))))
+        special_tokens = tuple(
+            enumerate(
+                (special + extras),
+                start=len(self.mergeable_ranks),
+            )
+        )
+        self.special_tokens = {token: index for index, token in special_tokens}
         enc = tiktoken.Encoding(
             "Qwen",
@@ -95,7 +87,7 @@ class QWenTokenizer(PreTrainedTokenizer):
             special_tokens=self.special_tokens,
         )
         assert (
-            len(self.mergeable_ranks) + len(self.special_tokens) == enc.n_vocab
+            len(self.mergeable_ranks) + len(self.special_tokens) <= enc.n_vocab
         ), f"{len(self.mergeable_ranks) + len(self.special_tokens)} != {enc.n_vocab} in encoding"
         self.decoder = {v: k for k, v in self.mergeable_ranks.items()}  # type: dict[int, bytes|str]
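The extras count in __init__ is computed so that the slot layout fills exactly config.vocab_size = 4096: 256 byte tokens from wit_b64.tiktoken, 3500 character tokens from wit_char.tiktoken, 5 named special tokens, and the remainder as <|extra_i|> placeholders. A quick check of that arithmetic (file line counts taken from this commit):

b64_tokens = 256      # wit/wit_b64.tiktoken, one byte token per line
char_tokens = 3500    # wit/wit_char.tiktoken, one character token per line
named_special = 5     # "user", "assistant", ENDOFTEXT, IMSTART, IMEND
extras = 4096 - (b64_tokens + char_tokens) - named_special  # 335 <|extra_i|> fillers

assert b64_tokens + char_tokens + named_special + extras == 4096  # matches config.vocab_size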
@@ -153,8 +145,6 @@ class QWenTokenizer(PreTrainedTokenizer):
             raise ValueError("Adding regular tokens is not supported")
         for token in new_tokens:
             surface_form = token.content if isinstance(token, AddedToken) else token
-            if surface_form not in SPECIAL_TOKENS_SET:
-                raise ValueError("Adding unknown special tokens is not supported")
         return 0

     def save_vocabulary(self, save_directory: str, **kwargs) -> Tuple[str]:

wit/wit_b64.tiktoken (new file, 256 lines)

@@ -0,0 +1,256 @@
IQ==
Ig==
Iw==
JA==
JQ==
Jg==
Jw==
KA==
KQ==
Kg==
Kw==
LA==
LQ==
Lg==
Lw==
MA==
MQ==
Mg==
Mw==
NA==
NQ==
Ng==
Nw==
OA==
OQ==
Og==
Ow==
PA==
PQ==
Pg==
Pw==
QA==
QQ==
Qg==
Qw==
RA==
RQ==
Rg==
Rw==
SA==
SQ==
Sg==
Sw==
TA==
TQ==
Tg==
Tw==
UA==
UQ==
Ug==
Uw==
VA==
VQ==
Vg==
Vw==
WA==
WQ==
Wg==
Ww==
XA==
XQ==
Xg==
Xw==
YA==
YQ==
Yg==
Yw==
ZA==
ZQ==
Zg==
Zw==
aA==
aQ==
ag==
aw==
bA==
bQ==
bg==
bw==
cA==
cQ==
cg==
cw==
dA==
dQ==
dg==
dw==
eA==
eQ==
eg==
ew==
fA==
fQ==
fg==
oQ==
og==
ow==
pA==
pQ==
pg==
pw==
qA==
qQ==
qg==
qw==
rA==
rg==
rw==
sA==
sQ==
sg==
sw==
tA==
tQ==
tg==
tw==
uA==
uQ==
ug==
uw==
vA==
vQ==
vg==
vw==
wA==
wQ==
wg==
ww==
xA==
xQ==
xg==
xw==
yA==
yQ==
yg==
yw==
zA==
zQ==
zg==
zw==
0A==
0Q==
0g==
0w==
1A==
1Q==
1g==
1w==
2A==
2Q==
2g==
2w==
3A==
3Q==
3g==
3w==
4A==
4Q==
4g==
4w==
5A==
5Q==
5g==
5w==
6A==
6Q==
6g==
6w==
7A==
7Q==
7g==
7w==
8A==
8Q==
8g==
8w==
9A==
9Q==
9g==
9w==
+A==
+Q==
+g==
+w==
/A==
/Q==
/g==
/w==
AA==
AQ==
Ag==
Aw==
BA==
BQ==
Bg==
Bw==
CA==
CQ==
Cg==
Cw==
DA==
DQ==
Dg==
Dw==
EA==
EQ==
Eg==
Ew==
FA==
FQ==
Fg==
Fw==
GA==
GQ==
Gg==
Gw==
HA==
HQ==
Hg==
Hw==
IA==
fw==
gA==
gQ==
gg==
gw==
hA==
hQ==
hg==
hw==
iA==
iQ==
ig==
iw==
jA==
jQ==
jg==
jw==
kA==
kQ==
kg==
kw==
lA==
lQ==
lg==
lw==
mA==
mQ==
mg==
mw==
nA==
nQ==
ng==
nw==
oA==
rQ==
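Each line of wit_b64.tiktoken is the base64 encoding of a single byte, and its rank is simply its line index (see _load_tiktoken_b64 above). Decoding the first few entries:

import base64

# The first lines of wit/wit_b64.tiktoken map printable ASCII bytes to the lowest ranks.
for rank, line in enumerate(["IQ==", "Ig==", "Iw=="]):
    print(rank, base64.b64decode(line))
# 0 b'!'
# 1 b'"'
# 2 b'#'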

wit/wit_char.tiktoken (new file, 3500 lines)

File diff suppressed because it is too large