Add wit.

parent 6366b52fef
commit fe13f12327
@@ -6,10 +6,10 @@

class QWenConfig:
    def __init__(self):
        self.vocab_size = 151936
        self.hidden_size = 2048
        self.num_hidden_layers = 24
        self.num_attention_heads = 16
        self.vocab_size = 4096
        self.hidden_size = 1024 # 1024 2048
        self.num_hidden_layers = 12 # 12 24
        self.num_attention_heads = 8 # 8 16
        self.emb_dropout_prob = 0.0
        self.attn_dropout_prob = 0.0
        self.layer_norm_epsilon = 1e-6
@@ -26,7 +26,7 @@ class QWenConfig:
        self.use_dynamic_ntk = True
        self.use_logn_attn = True
        self.use_flash_attn = "auto"
        self.intermediate_size = 11008
        self.intermediate_size = 5504 # 5504 11008
        self.no_bias = True
        self.tie_word_embeddings = False
        self.use_cache_quantization = False
@@ -34,8 +34,6 @@ class QWenConfig:
        self.softmax_in_fp32 = False

        self.chat_format = "chatml"
        self.eos_token_id = 151643
        self.pad_token_id = 151643
        self.max_window_size = 6144
        self.max_new_tokens = 512
        self.do_sample = True
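For a sense of scale: the halved hidden size, depth, head count and intermediate size shrink the network by roughly an order of magnitude. A back-of-the-envelope sketch (the approx_params helper and its formula are illustrative assumptions, not code from this repo):

```python
# Rough, illustrative parameter-count estimate for the config change above
# (a sketch only; approx_params is not part of the repository).
def approx_params(vocab_size, hidden_size, num_layers, intermediate_size):
    embed = vocab_size * hidden_size           # token embedding (an untied lm_head roughly doubles this term)
    attn = 4 * hidden_size * hidden_size       # q, k, v and output projections per layer
    mlp = 3 * hidden_size * intermediate_size  # gate/up/down projections per layer
    return embed + num_layers * (attn + mlp)

print(approx_params(151936, 2048, 24, 11008))  # old values: ~2.3e9 by this formula
print(approx_params(4096, 1024, 12, 5504))     # new values: ~2.6e8 by this formula
```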
wit/demo.py (43 lines changed)
@@ -1,4 +1,5 @@
import torch
import sys
from modelscope import snapshot_download

from modeling_wit import QWenLMHeadModel
@@ -6,6 +7,12 @@ from modeling_wit import QwenRunner
from configuration_qwen import QWenConfig
from tokenization_qwen import QWenTokenizer


from qwen_generation_utils import (
    make_context,
    decode_tokens,
)

seed = 4321
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
@@ -18,14 +25,46 @@ model = QWenLMHeadModel(config)

print(model)

tokenizer = QWenTokenizer("./qwen.tiktoken")
tokenizer = QWenTokenizer("./wit_b64.tiktoken", "./wit_char.tiktoken")

sys.path.append("..")
from tools import show


def Dump_tokens_list(model):
    tokens = []
    for token in range(4096):
        decoded, response, end_reason = decode_tokens(
            [token],
            tokenizer,
            raw_text_len=0,
            context_length=0,
            errors="replace",
        )
        tokens.append(str(token).zfill(7) + ": " + repr(decoded))
    show.DumpListToFile(tokens, "./temp/qwen_token_list.txt")


Dump_tokens_list(model)


model = model.from_pretrained(model_dir).cuda()

# state = model.state_dict()
# torch.save(state, "model_params.pth")
# model.load_state_dict(torch.load('model_params.pth'))


model = model.eval()
# model = model.train() # control by @torch.no_grad()


runner = QwenRunner(model)

response, history, decode_tokens = runner.Chat(tokenizer, "东南亚国家日本的首都是什么市", "")
output_ids, history, decode_tokens = runner.Chat(tokenizer, "你好", "", 20)
print(decode_tokens)

for i, token in enumerate(output_ids):
    de = tokenizer.decode([token])
    de = str(i + 1).zfill(3) + " : " + repr(de)
    print(de)
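As a side note, Dump_tokens_list writes one formatted line per token id through the project's tools.show helper. A self-contained sketch of the same line format, using a stand-in decoder dict instead of the real tokenizer (both the dict and the output path here are hypothetical):

```python
import os

# Stand-in id -> bytes mapping; the real script decodes ids 0..4095 with QWenTokenizer.
decoder = {0: b"!", 1: b'"', 2: b"#"}

lines = []
for token, raw in decoder.items():
    decoded = raw.decode("utf-8", errors="replace")
    lines.append(str(token).zfill(7) + ": " + repr(decoded))  # e.g. 0000000: '!'

os.makedirs("./temp", exist_ok=True)
with open("./temp/token_list.txt", "w", encoding="utf-8") as f:
    f.write("\n".join(lines) + "\n")
```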
@@ -193,14 +193,18 @@ class QwenRunner:
        tokenizer,
        query: str,
        query_assistant: str,
        gen_length=0,
        system: str = "You are a helpful assistant.",
        history=[],
    ):
        qwen = self.qwen
        history = copy.deepcopy(history)
        self.qwen.config.pad_token_id = tokenizer.eod_id
        self.qwen.config.eos_token_id = tokenizer.eod_id
        raw_text, context_tokens = self.prepareInput(tokenizer, query, query_assistant, history, system)
        input_ids = torch.tensor([context_tokens]).to(next(qwen.parameters()).device)
        self.unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
        input_length = input_ids.shape[1]
        while True:
            outputs = self.forwardQWen(input_ids)
            next_token_scores = outputs[:, -1, :]
@@ -212,6 +216,8 @@ class QwenRunner:
            if finish:
                break
            input_ids = torch.cat([input_ids, next_tokens], dim=-1)
            if gen_length != 0 and (input_length + gen_length) < input_ids.shape[1]:
                break

        decoded, response, end_reason = decode_tokens(
            input_ids[0],
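The new gen_length argument caps how many tokens are generated in addition to the EOS check. A minimal sketch of that loop in isolation (next_token stands in for forwardQWen plus the sampling step; it is not a function from the repo):

```python
import torch

def generate_capped(input_ids: torch.Tensor, next_token, eos_id: int, gen_length: int = 0):
    """Append sampled tokens until EOS, or until more than gen_length new tokens exist (0 = no cap)."""
    input_length = input_ids.shape[1]
    while True:
        tokens = next_token(input_ids)  # shape [batch, 1]; stand-in for model forward + sampling
        if (tokens == eos_id).all():    # simplified stand-in for the runner's finish check
            break
        input_ids = torch.cat([input_ids, tokens], dim=-1)
        if gen_length != 0 and (input_length + gen_length) < input_ids.shape[1]:
            break
    return input_ids
```

With gen_length=20 this mirrors the runner.Chat(tokenizer, "你好", "", 20) call in demo.py, which stops once more than 20 tokens have been appended.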
@@ -26,31 +26,21 @@ IMEND = "<|im_end|>"
# as the default behavior is changed to allow special tokens in
# regular texts, the surface forms of special tokens need to be
# as different as possible to minimize the impact
EXTRAS = tuple((f"<|extra_{i}|>" for i in range(205)))
# changed to use actual index to avoid misconfiguration with vocabulary expansion
SPECIAL_START_ID = 151643
SPECIAL_TOKENS = tuple(
    enumerate(
        (
            (
                ENDOFTEXT,
                IMSTART,
                IMEND,
            )
            + EXTRAS
        ),
        start=SPECIAL_START_ID,
    )
)
SPECIAL_TOKENS_SET = set(t for i, t in SPECIAL_TOKENS)


def _load_tiktoken_bpe(tiktoken_bpe_file: str) -> Dict[bytes, int]:
def _load_tiktoken_b64(tiktoken_bpe_file: str, startRank=0) -> Dict[bytes, int]:
    with open(tiktoken_bpe_file, "rb") as f:
        contents = f.read()
    return {
        base64.b64decode(token): int(rank) for token, rank in (line.split() for line in contents.splitlines() if line)
    }
    ll = contents.splitlines()
    return {base64.b64decode(token): int(rank + startRank) for rank, token in enumerate(ll)}


def _load_tiktoken_char(tiktoken_bpe_file: str, startRank=0) -> Dict[bytes, int]:
    with open(tiktoken_bpe_file, "rb") as f:
        contents = f.read()
    ll = contents.splitlines()
    return {token: int(rank + startRank) for rank, token in enumerate(ll)}


class QWenTokenizer(PreTrainedTokenizer):
@@ -60,9 +50,9 @@ class QWenTokenizer(PreTrainedTokenizer):

    def __init__(
        self,
        vocab_file,
        vocab_file_b64,
        vocab_file_char,
        errors="replace",
        extra_vocab_file=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
@@ -71,22 +61,24 @@ class QWenTokenizer(PreTrainedTokenizer):
        # use ignore if you are in streaming inference
        self.errors = errors

        self.mergeable_ranks = _load_tiktoken_bpe(vocab_file) # type: Dict[bytes, int]
        self.special_tokens = {token: index for index, token in SPECIAL_TOKENS}
        self.mergeable_ranks = _load_tiktoken_b64(vocab_file_b64)
        self.mergeable_ranks.update(_load_tiktoken_char(vocab_file_char, len(self.mergeable_ranks)))

        # try load extra vocab from file
        if extra_vocab_file is not None:
            used_ids = set(self.mergeable_ranks.values()) | set(self.special_tokens.values())
            extra_mergeable_ranks = _load_tiktoken_bpe(extra_vocab_file)
            for token, index in extra_mergeable_ranks.items():
                if token in self.mergeable_ranks:
                    logger.info(f"extra token {token} exists, skipping")
                    continue
                if index in used_ids:
                    logger.info(f"the index {index} for extra token {token} exists, skipping")
                    continue
                self.mergeable_ranks[token] = index
            # the index may be sparse after this, but don't worry tiktoken.Encoding will handle this
        special = (
            "user",
            "assistant",
            ENDOFTEXT,
            IMSTART,
            IMEND,
        )
        extras = tuple((f"<|extra_{i}|>" for i in range(4096 - len(self.mergeable_ranks) - len(special))))
        special_tokens = tuple(
            enumerate(
                (special + extras),
                start=len(self.mergeable_ranks),
            )
        )
        self.special_tokens = {token: index for index, token in special_tokens}

        enc = tiktoken.Encoding(
            "Qwen",
@@ -95,7 +87,7 @@ class QWenTokenizer(PreTrainedTokenizer):
            special_tokens=self.special_tokens,
        )
        assert (
            len(self.mergeable_ranks) + len(self.special_tokens) == enc.n_vocab
            len(self.mergeable_ranks) + len(self.special_tokens) <= enc.n_vocab
        ), f"{len(self.mergeable_ranks) + len(self.special_tokens)} != {enc.n_vocab} in encoding"

        self.decoder = {v: k for k, v in self.mergeable_ranks.items()} # type: dict[int, bytes|str]
@@ -153,8 +145,6 @@ class QWenTokenizer(PreTrainedTokenizer):
            raise ValueError("Adding regular tokens is not supported")
        for token in new_tokens:
            surface_form = token.content if isinstance(token, AddedToken) else token
            if surface_form not in SPECIAL_TOKENS_SET:
                raise ValueError("Adding unknown special tokens is not supported")
        return 0

    def save_vocabulary(self, save_directory: str, **kwargs) -> Tuple[str]:
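To illustrate how the two vocab files compose into a single rank table, here is a sketch with in-memory stand-ins for ./wit_b64.tiktoken and ./wit_char.tiktoken; the logic mirrors _load_tiktoken_b64, _load_tiktoken_char and the special-token block above, but none of this is repository code:

```python
import base64

b64_lines = b"IQ==\nIg==\nIw==".splitlines()             # stand-in for wit_b64.tiktoken: one base64 token per line
char_lines = "你\n好\n吗".encode("utf-8").splitlines()    # stand-in for wit_char.tiktoken: one raw token per line

# Byte-level tokens come first; the character file continues the rank numbering.
ranks = {base64.b64decode(tok): rank for rank, tok in enumerate(b64_lines)}
offset = len(ranks)
ranks.update({tok: rank + offset for rank, tok in enumerate(char_lines)})

# Special tokens are appended after all mergeable ranks.
special = ("user", "assistant", "<|endoftext|>", "<|im_start|>", "<|im_end|>")
special_tokens = {tok: idx for idx, tok in enumerate(special, start=len(ranks))}

print(ranks)           # {b'!': 0, b'"': 1, b'#': 2, b'\xe4\xbd\xa0': 3, ...}
print(special_tokens)  # {'user': 6, 'assistant': 7, ...}
```

In the real tokenizer, <|extra_i|> placeholders then pad the table out to the 4096-entry vocabulary before tiktoken.Encoding is built.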
@@ -0,0 +1,256 @@
IQ==
Ig==
Iw==
JA==
JQ==
Jg==
Jw==
KA==
KQ==
Kg==
Kw==
LA==
LQ==
Lg==
Lw==
MA==
MQ==
Mg==
Mw==
NA==
NQ==
Ng==
Nw==
OA==
OQ==
Og==
Ow==
PA==
PQ==
Pg==
Pw==
QA==
QQ==
Qg==
Qw==
RA==
RQ==
Rg==
Rw==
SA==
SQ==
Sg==
Sw==
TA==
TQ==
Tg==
Tw==
UA==
UQ==
Ug==
Uw==
VA==
VQ==
Vg==
Vw==
WA==
WQ==
Wg==
Ww==
XA==
XQ==
Xg==
Xw==
YA==
YQ==
Yg==
Yw==
ZA==
ZQ==
Zg==
Zw==
aA==
aQ==
ag==
aw==
bA==
bQ==
bg==
bw==
cA==
cQ==
cg==
cw==
dA==
dQ==
dg==
dw==
eA==
eQ==
eg==
ew==
fA==
fQ==
fg==
oQ==
og==
ow==
pA==
pQ==
pg==
pw==
qA==
qQ==
qg==
qw==
rA==
rg==
rw==
sA==
sQ==
sg==
sw==
tA==
tQ==
tg==
tw==
uA==
uQ==
ug==
uw==
vA==
vQ==
vg==
vw==
wA==
wQ==
wg==
ww==
xA==
xQ==
xg==
xw==
yA==
yQ==
yg==
yw==
zA==
zQ==
zg==
zw==
0A==
0Q==
0g==
0w==
1A==
1Q==
1g==
1w==
2A==
2Q==
2g==
2w==
3A==
3Q==
3g==
3w==
4A==
4Q==
4g==
4w==
5A==
5Q==
5g==
5w==
6A==
6Q==
6g==
6w==
7A==
7Q==
7g==
7w==
8A==
8Q==
8g==
8w==
9A==
9Q==
9g==
9w==
+A==
+Q==
+g==
+w==
/A==
/Q==
/g==
/w==
AA==
AQ==
Ag==
Aw==
BA==
BQ==
Bg==
Bw==
CA==
CQ==
Cg==
Cw==
DA==
DQ==
Dg==
Dw==
EA==
EQ==
Eg==
Ew==
FA==
FQ==
Fg==
Fw==
GA==
GQ==
Gg==
Gw==
HA==
HQ==
Hg==
Hw==
IA==
fw==
gA==
gQ==
gg==
gw==
hA==
hQ==
hg==
hw==
iA==
iQ==
ig==
iw==
jA==
jQ==
jg==
jw==
kA==
kQ==
kg==
kw==
lA==
lQ==
lg==
lw==
mA==
mQ==
mg==
mw==
nA==
nQ==
ng==
nw==
oA==
rQ==
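The 256 lines above appear to be the base64 encodings of every single-byte value, with printable ASCII first, giving the byte-level part of the new vocabulary. A quick check one could run (assuming this new file is the ./wit_b64.tiktoken that demo.py loads):

```python
import base64

with open("./wit_b64.tiktoken", "rb") as f:
    tokens = [base64.b64decode(line) for line in f.read().splitlines() if line]

print(len(tokens), tokens[:4])  # expected: 256 [b'!', b'"', b'#', b'$']
# Each of the 256 byte values should appear exactly once.
print(sorted(tokens) == [bytes([i]) for i in range(256)])
```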
File diff suppressed because it is too large