Add wit.

commit fe13f12327
parent 6366b52fef
configuration_qwen.py

@@ -6,10 +6,10 @@
 class QWenConfig:
     def __init__(self):
-        self.vocab_size = 151936
+        self.vocab_size = 4096
-        self.hidden_size = 2048
+        self.hidden_size = 1024  # 1024 2048
-        self.num_hidden_layers = 24
+        self.num_hidden_layers = 12  # 12 24
-        self.num_attention_heads = 16
+        self.num_attention_heads = 8  # 8 16
         self.emb_dropout_prob = 0.0
         self.attn_dropout_prob = 0.0
         self.layer_norm_epsilon = 1e-6
@@ -26,7 +26,7 @@ class QWenConfig:
         self.use_dynamic_ntk = True
         self.use_logn_attn = True
         self.use_flash_attn = "auto"
-        self.intermediate_size = 11008
+        self.intermediate_size = 5504  # 5504 11008
         self.no_bias = True
         self.tie_word_embeddings = False
         self.use_cache_quantization = False
@@ -34,8 +34,6 @@ class QWenConfig:
         self.softmax_in_fp32 = False

         self.chat_format = "chatml"
-        self.eos_token_id = 151643
-        self.pad_token_id = 151643
         self.max_window_size = 6144
         self.max_new_tokens = 512
         self.do_sample = True
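Note on the size change: the original values (vocab 151936, hidden 2048, 24 layers, 16 heads, intermediate 11008) match Qwen-1.8B; the halved settings shrink the model to roughly 0.16B parameters. A back-of-the-envelope sketch, assuming untied embeddings and Qwen's split MLP where w1 and w2 each map hidden to intermediate_size // 2 (an estimate, not an exact count):

def approx_params(vocab, hidden, layers, inter):
    embed = vocab * hidden              # input embedding
    head = vocab * hidden               # lm_head (tie_word_embeddings = False)
    attn = 4 * hidden * hidden          # q, k, v and output projections
    mlp = 3 * hidden * (inter // 2)     # w1, w2, c_proj
    return embed + head + layers * (attn + mlp)

print(approx_params(151936, 2048, 24, 11008))  # ~1.84e9, the original config
print(approx_params(4096, 1024, 12, 5504))     # ~1.60e8, this commit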
wit/demo.py | 43

@@ -1,4 +1,5 @@
 import torch
+import sys
 from modelscope import snapshot_download

 from modeling_wit import QWenLMHeadModel

@@ -6,6 +7,12 @@ from modeling_wit import QwenRunner
 from configuration_qwen import QWenConfig
 from tokenization_qwen import QWenTokenizer

+from qwen_generation_utils import (
+    make_context,
+    decode_tokens,
+)
+
 seed = 4321
 torch.manual_seed(seed)
 torch.cuda.manual_seed_all(seed)
@@ -18,14 +25,46 @@ model = QWenLMHeadModel(config)

 print(model)

-tokenizer = QWenTokenizer("./qwen.tiktoken")
+tokenizer = QWenTokenizer("./wit_b64.tiktoken", "./wit_char.tiktoken")

+sys.path.append("..")
+from tools import show
+
+
+def Dump_tokens_list(model):
+    tokens = []
+    for token in range(4096):
+        decoded, response, end_reason = decode_tokens(
+            [token],
+            tokenizer,
+            raw_text_len=0,
+            context_length=0,
+            errors="replace",
+        )
+        tokens.append(str(token).zfill(7) + ": " + repr(decoded))
+    show.DumpListToFile(tokens, "./temp/qwen_token_list.txt")
+
+
+Dump_tokens_list(model)
+
+
 model = model.from_pretrained(model_dir).cuda()

+# state = model.state_dict()
+# torch.save(state, "model_params.pth")
+# model.load_state_dict(torch.load('model_params.pth'))

 model = model.eval()
 # model = model.train()  # control by @torch.no_grad()

 runner = QwenRunner(model)

-response, history, decode_tokens = runner.Chat(tokenizer, "东南亚国家日本的首都是什么市", "")
+output_ids, history, decode_tokens = runner.Chat(tokenizer, "你好", "", 20)
 print(decode_tokens)

+for i, token in enumerate(output_ids):
+    de = tokenizer.decode([token])
+    de = str(i + 1).zfill(3) + " : " + repr(de)
+    print(de)
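Review note: the demo prompt changes from the trick question "东南亚国家日本的首都是什么市" ("What city is the capital of the Southeast Asian country Japan?") to "你好" ("Hello") with a 20-token generation cap, and Chat now returns raw output_ids that the new loop prints one id per line. Dump_tokens_list writes the surface form of every id (0 through 4095) to ./temp/qwen_token_list.txt; the same dump can be sketched with just the tokenizer and its standard decode method (helper name hypothetical):

def dump_token_list(tokenizer, path="./temp/qwen_token_list.txt", vocab_size=4096):
    # one line per id: zero-padded id, then repr of its decoded surface form
    lines = [str(i).zfill(7) + ": " + repr(tokenizer.decode([i])) for i in range(vocab_size)]
    with open(path, "w", encoding="utf-8") as f:
        f.write("\n".join(lines))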
modeling_wit.py

@@ -193,14 +193,18 @@ class QwenRunner:
         tokenizer,
         query: str,
         query_assistant: str,
+        gen_length=0,
         system: str = "You are a helpful assistant.",
         history=[],
     ):
         qwen = self.qwen
         history = copy.deepcopy(history)
+        self.qwen.config.pad_token_id = tokenizer.eod_id
+        self.qwen.config.eos_token_id = tokenizer.eod_id
         raw_text, context_tokens = self.prepareInput(tokenizer, query, query_assistant, history, system)
         input_ids = torch.tensor([context_tokens]).to(next(qwen.parameters()).device)
         self.unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
+        input_length = input_ids.shape[1]
         while True:
             outputs = self.forwardQWen(input_ids)
             next_token_scores = outputs[:, -1, :]

@@ -212,6 +216,8 @@ class QwenRunner:
             if finish:
                 break
             input_ids = torch.cat([input_ids, next_tokens], dim=-1)
+            if gen_length != 0 and (input_length + gen_length) < input_ids.shape[1]:
+                break

         decoded, response, end_reason = decode_tokens(
             input_ids[0],
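Review note on the gen_length cap: after k new tokens have been appended, input_ids.shape[1] equals input_length + k, so (input_length + gen_length) < input_ids.shape[1] first trips at k = gen_length + 1, i.e. the loop can emit one token more than requested. If the intent is at most gen_length tokens, a tighter check would be (a sketch using the same variables, not the committed code):

generated = input_ids.shape[1] - input_length
if gen_length != 0 and generated >= gen_length:
    break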
tokenization_qwen.py

@@ -26,31 +26,21 @@ IMEND = "<|im_end|>"
 # as the default behavior is changed to allow special tokens in
 # regular texts, the surface forms of special tokens need to be
 # as different as possible to minimize the impact
-EXTRAS = tuple((f"<|extra_{i}|>" for i in range(205)))
 # changed to use actual index to avoid misconfiguration with vocabulary expansion
-SPECIAL_START_ID = 151643
-SPECIAL_TOKENS = tuple(
-    enumerate(
-        (
-            (
-                ENDOFTEXT,
-                IMSTART,
-                IMEND,
-            )
-            + EXTRAS
-        ),
-        start=SPECIAL_START_ID,
-    )
-)
-SPECIAL_TOKENS_SET = set(t for i, t in SPECIAL_TOKENS)


-def _load_tiktoken_bpe(tiktoken_bpe_file: str) -> Dict[bytes, int]:
+def _load_tiktoken_b64(tiktoken_bpe_file: str, startRank=0) -> Dict[bytes, int]:
     with open(tiktoken_bpe_file, "rb") as f:
         contents = f.read()
-    return {
-        base64.b64decode(token): int(rank) for token, rank in (line.split() for line in contents.splitlines() if line)
-    }
+    ll = contents.splitlines()
+    return {base64.b64decode(token): int(rank + startRank) for rank, token in enumerate(ll)}
+
+
+def _load_tiktoken_char(tiktoken_bpe_file: str, startRank=0) -> Dict[bytes, int]:
+    with open(tiktoken_bpe_file, "rb") as f:
+        contents = f.read()
+    ll = contents.splitlines()
+    return {token: int(rank + startRank) for rank, token in enumerate(ll)}


 class QWenTokenizer(PreTrainedTokenizer):
@@ -60,9 +50,9 @@ class QWenTokenizer(PreTrainedTokenizer):

     def __init__(
         self,
-        vocab_file,
+        vocab_file_b64,
+        vocab_file_char,
         errors="replace",
-        extra_vocab_file=None,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -71,22 +61,24 @@ class QWenTokenizer(PreTrainedTokenizer):
         # use ignore if you are in streaming inference
         self.errors = errors

-        self.mergeable_ranks = _load_tiktoken_bpe(vocab_file)  # type: Dict[bytes, int]
-        self.special_tokens = {token: index for index, token in SPECIAL_TOKENS}
+        self.mergeable_ranks = _load_tiktoken_b64(vocab_file_b64)
+        self.mergeable_ranks.update(_load_tiktoken_char(vocab_file_char, len(self.mergeable_ranks)))

-        # try load extra vocab from file
-        if extra_vocab_file is not None:
-            used_ids = set(self.mergeable_ranks.values()) | set(self.special_tokens.values())
-            extra_mergeable_ranks = _load_tiktoken_bpe(extra_vocab_file)
-            for token, index in extra_mergeable_ranks.items():
-                if token in self.mergeable_ranks:
-                    logger.info(f"extra token {token} exists, skipping")
-                    continue
-                if index in used_ids:
-                    logger.info(f"the index {index} for extra token {token} exists, skipping")
-                    continue
-                self.mergeable_ranks[token] = index
-            # the index may be sparse after this, but don't worry tiktoken.Encoding will handle this
+        special = (
+            "user",
+            "assistant",
+            ENDOFTEXT,
+            IMSTART,
+            IMEND,
+        )
+        extras = tuple((f"<|extra_{i}|>" for i in range(4096 - len(self.mergeable_ranks) - len(special))))
+        special_tokens = tuple(
+            enumerate(
+                (special + extras),
+                start=len(self.mergeable_ranks),
+            )
+        )
+        self.special_tokens = {token: index for index, token in special_tokens}

         enc = tiktoken.Encoding(
             "Qwen",
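How the 4096-id vocabulary is assembled after this change (a sketch; file paths taken from wit/demo.py): the b64 file supplies the 256 single-byte tokens at ranks 0-255, the char file is appended starting at rank 256, and the specials ("user", "assistant", ENDOFTEXT, IMSTART, IMEND, plus the <|extra_i|> fillers) are enumerated from len(mergeable_ranks) so the total lands exactly on config.vocab_size = 4096. Note that the assert in the next hunk is relaxed from == to <=, though its failure message still prints !=.

b64_ranks = _load_tiktoken_b64("./wit_b64.tiktoken")                     # bytes -> ranks 0..255
char_ranks = _load_tiktoken_char("./wit_char.tiktoken", len(b64_ranks))  # ranks 256..N-1
mergeable = {**b64_ranks, **char_ranks}
# specials then occupy ids N..4095: 5 named tokens + (4096 - N - 5) extras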
@@ -95,7 +87,7 @@ class QWenTokenizer(PreTrainedTokenizer):
             special_tokens=self.special_tokens,
         )
         assert (
-            len(self.mergeable_ranks) + len(self.special_tokens) == enc.n_vocab
+            len(self.mergeable_ranks) + len(self.special_tokens) <= enc.n_vocab
         ), f"{len(self.mergeable_ranks) + len(self.special_tokens)} != {enc.n_vocab} in encoding"

         self.decoder = {v: k for k, v in self.mergeable_ranks.items()}  # type: dict[int, bytes|str]
@@ -153,8 +145,6 @@ class QWenTokenizer(PreTrainedTokenizer):
            raise ValueError("Adding regular tokens is not supported")
        for token in new_tokens:
            surface_form = token.content if isinstance(token, AddedToken) else token
-            if surface_form not in SPECIAL_TOKENS_SET:
-                raise ValueError("Adding unknown special tokens is not supported")
        return 0

    def save_vocabulary(self, save_directory: str, **kwargs) -> Tuple[str]:
wit/wit_b64.tiktoken (new file)

@@ -0,0 +1,256 @@
+IQ==
+Ig==
+Iw==
+JA==
+JQ==
+Jg==
+Jw==
+KA==
+KQ==
+Kg==
+Kw==
+LA==
+LQ==
+Lg==
+Lw==
+MA==
+MQ==
+Mg==
+Mw==
+NA==
+NQ==
+Ng==
+Nw==
+OA==
+OQ==
+Og==
+Ow==
+PA==
+PQ==
+Pg==
+Pw==
+QA==
+QQ==
+Qg==
+Qw==
+RA==
+RQ==
+Rg==
+Rw==
+SA==
+SQ==
+Sg==
+Sw==
+TA==
+TQ==
+Tg==
+Tw==
+UA==
+UQ==
+Ug==
+Uw==
+VA==
+VQ==
+Vg==
+Vw==
+WA==
+WQ==
+Wg==
+Ww==
+XA==
+XQ==
+Xg==
+Xw==
+YA==
+YQ==
+Yg==
+Yw==
+ZA==
+ZQ==
+Zg==
+Zw==
+aA==
+aQ==
+ag==
+aw==
+bA==
+bQ==
+bg==
+bw==
+cA==
+cQ==
+cg==
+cw==
+dA==
+dQ==
+dg==
+dw==
+eA==
+eQ==
+eg==
+ew==
+fA==
+fQ==
+fg==
+oQ==
+og==
+ow==
+pA==
+pQ==
+pg==
+pw==
+qA==
+qQ==
+qg==
+qw==
+rA==
+rg==
+rw==
+sA==
+sQ==
+sg==
+sw==
+tA==
+tQ==
+tg==
+tw==
+uA==
+uQ==
+ug==
+uw==
+vA==
+vQ==
+vg==
+vw==
+wA==
+wQ==
+wg==
+ww==
+xA==
+xQ==
+xg==
+xw==
+yA==
+yQ==
+yg==
+yw==
+zA==
+zQ==
+zg==
+zw==
+0A==
+0Q==
+0g==
+0w==
+1A==
+1Q==
+1g==
+1w==
+2A==
+2Q==
+2g==
+2w==
+3A==
+3Q==
+3g==
+3w==
+4A==
+4Q==
+4g==
+4w==
+5A==
+5Q==
+5g==
+5w==
+6A==
+6Q==
+6g==
+6w==
+7A==
+7Q==
+7g==
+7w==
+8A==
+8Q==
+8g==
+8w==
+9A==
+9Q==
+9g==
+9w==
++A==
++Q==
++g==
++w==
+/A==
+/Q==
+/g==
+/w==
+AA==
+AQ==
+Ag==
+Aw==
+BA==
+BQ==
+Bg==
+Bw==
+CA==
+CQ==
+Cg==
+Cw==
+DA==
+DQ==
+Dg==
+Dw==
+EA==
+EQ==
+Eg==
+Ew==
+FA==
+FQ==
+Fg==
+Fw==
+GA==
+GQ==
+Gg==
+Gw==
+HA==
+HQ==
+Hg==
+Hw==
+IA==
+fw==
+gA==
+gQ==
+gg==
+gw==
+hA==
+hQ==
+hg==
+hw==
+iA==
+iQ==
+ig==
+iw==
+jA==
+jQ==
+jg==
+jw==
+kA==
+kQ==
+kg==
+kw==
+lA==
+lQ==
+lg==
+lw==
+mA==
+mQ==
+mg==
+mw==
+nA==
+nQ==
+ng==
+nw==
+oA==
+rQ==
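The 256 entries of the new file decode to exactly the 256 single-byte values, in the same order as the familiar GPT-2 byte table: printable bytes 0x21-0x7e first, then 0xa1-0xac and 0xae-0xff, then the remapped bytes 0x00-0x20, 0x7f, 0x80-0xa0, and finally 0xad. A quick sanity check (path assumed from wit/demo.py):

import base64

with open("./wit_b64.tiktoken", "rb") as f:
    toks = [base64.b64decode(line) for line in f.read().splitlines()]
assert len(toks) == 256 and all(len(t) == 1 for t in toks)
assert len(set(toks)) == 256  # every byte value appears exactly once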
File diff suppressed because it is too large