Add wit.

commit fe13f12327
parent 6366b52fef
@@ -6,10 +6,10 @@
 class QWenConfig:
     def __init__(self):
-        self.vocab_size = 151936
-        self.hidden_size = 2048
-        self.num_hidden_layers = 24
-        self.num_attention_heads = 16
+        self.vocab_size = 4096
+        self.hidden_size = 1024  # 1024 2048
+        self.num_hidden_layers = 12  # 12 24
+        self.num_attention_heads = 8  # 8 16
         self.emb_dropout_prob = 0.0
         self.attn_dropout_prob = 0.0
         self.layer_norm_epsilon = 1e-6
@@ -26,7 +26,7 @@ class QWenConfig:
         self.use_dynamic_ntk = True
         self.use_logn_attn = True
         self.use_flash_attn = "auto"
-        self.intermediate_size = 11008
+        self.intermediate_size = 5504  # 5504 11008
         self.no_bias = True
         self.tie_word_embeddings = False
         self.use_cache_quantization = False
@@ -34,8 +34,6 @@ class QWenConfig:
         self.softmax_in_fp32 = False

         self.chat_format = "chatml"
-        self.eos_token_id = 151643
-        self.pad_token_id = 151643
         self.max_window_size = 6144
         self.max_new_tokens = 512
         self.do_sample = True
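
The hunks above (presumably wit/configuration_qwen.py, given the "from configuration_qwen import QWenConfig" line in wit/demo.py below) shrink the stock Qwen dimensions to a small experimental configuration: vocabulary 151936 -> 4096, hidden size 2048 -> 1024, 24 -> 12 layers, 16 -> 8 heads, MLP width 11008 -> 5504, and the hard-coded eos/pad ids are dropped (the runner below now takes them from the tokenizer). For a rough sense of the size reduction, here is a minimal sketch that assumes the usual Qwen block layout (untied embeddings, bias-free q/k/v/o projections, a three-matrix SwiGLU MLP) and ignores norm weights; these layout assumptions come from Qwen in general, not from this diff:

def approx_params(vocab_size, hidden_size, num_layers, intermediate_size):
    # Hedged estimate: untied input/output embeddings, 4 attention projections,
    # 3 SwiGLU MLP matrices, no biases, norm weights ignored.
    embeddings = 2 * vocab_size * hidden_size
    attention = 4 * hidden_size * hidden_size
    mlp = 3 * hidden_size * intermediate_size
    return embeddings + num_layers * (attention + mlp)

old = approx_params(151936, 2048, 24, 11008)  # values removed by this commit
new = approx_params(4096, 1024, 12, 5504)     # values introduced by this commit
print(f"{old / 1e9:.2f}B -> {new / 1e6:.0f}M")  # roughly 2.65B -> 262M
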
wit/demo.py (43 lines changed)
@@ -1,4 +1,5 @@
 import torch
+import sys
 from modelscope import snapshot_download

 from modeling_wit import QWenLMHeadModel
@@ -6,6 +7,12 @@ from modeling_wit import QwenRunner
 from configuration_qwen import QWenConfig
 from tokenization_qwen import QWenTokenizer


+from qwen_generation_utils import (
+    make_context,
+    decode_tokens,
+)

 seed = 4321
 torch.manual_seed(seed)
 torch.cuda.manual_seed_all(seed)
@@ -18,14 +25,46 @@ model = QWenLMHeadModel(config)

 print(model)

-tokenizer = QWenTokenizer("./qwen.tiktoken")
+tokenizer = QWenTokenizer("./wit_b64.tiktoken", "./wit_char.tiktoken")

+sys.path.append("..")
+from tools import show


+def Dump_tokens_list(model):
+    tokens = []
+    for token in range(4096):
+        decoded, response, end_reason = decode_tokens(
+            [token],
+            tokenizer,
+            raw_text_len=0,
+            context_length=0,
+            errors="replace",
+        )
+        tokens.append(str(token).zfill(7) + ": " + repr(decoded))
+    show.DumpListToFile(tokens, "./temp/qwen_token_list.txt")


+Dump_tokens_list(model)


 model = model.from_pretrained(model_dir).cuda()

 # state = model.state_dict()
 # torch.save(state, "model_params.pth")
 # model.load_state_dict(torch.load('model_params.pth'))


 model = model.eval()
 # model = model.train()  # control by @torch.no_grad()


 runner = QwenRunner(model)

-response, history, decode_tokens = runner.Chat(tokenizer, "东南亚国家日本的首都是什么市", "")
+output_ids, history, decode_tokens = runner.Chat(tokenizer, "你好", "", 20)
 print(decode_tokens)

+for i, token in enumerate(output_ids):
+    de = tokenizer.decode([token])
+    de = str(i + 1).zfill(3) + " : " + repr(de)
+    print(de)
@@ -193,14 +193,18 @@ class QwenRunner:
         tokenizer,
         query: str,
         query_assistant: str,
+        gen_length=0,
         system: str = "You are a helpful assistant.",
         history=[],
     ):
         qwen = self.qwen
         history = copy.deepcopy(history)
+        self.qwen.config.pad_token_id = tokenizer.eod_id
+        self.qwen.config.eos_token_id = tokenizer.eod_id
         raw_text, context_tokens = self.prepareInput(tokenizer, query, query_assistant, history, system)
         input_ids = torch.tensor([context_tokens]).to(next(qwen.parameters()).device)
         self.unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
+        input_length = input_ids.shape[1]
         while True:
             outputs = self.forwardQWen(input_ids)
             next_token_scores = outputs[:, -1, :]
@@ -212,6 +216,8 @@ class QwenRunner:
             if finish:
                 break
             input_ids = torch.cat([input_ids, next_tokens], dim=-1)
+            if gen_length != 0 and (input_length + gen_length) < input_ids.shape[1]:
+                break

         decoded, response, end_reason = decode_tokens(
             input_ids[0],
@@ -26,31 +26,21 @@ IMEND = "<|im_end|>"
 # as the default behavior is changed to allow special tokens in
 # regular texts, the surface forms of special tokens need to be
 # as different as possible to minimize the impact
-EXTRAS = tuple((f"<|extra_{i}|>" for i in range(205)))
-# changed to use actual index to avoid misconfiguration with vocabulary expansion
-SPECIAL_START_ID = 151643
-SPECIAL_TOKENS = tuple(
-    enumerate(
-        (
-            (
-                ENDOFTEXT,
-                IMSTART,
-                IMEND,
-            )
-            + EXTRAS
-        ),
-        start=SPECIAL_START_ID,
-    )
-)
-SPECIAL_TOKENS_SET = set(t for i, t in SPECIAL_TOKENS)


-def _load_tiktoken_bpe(tiktoken_bpe_file: str) -> Dict[bytes, int]:
+def _load_tiktoken_b64(tiktoken_bpe_file: str, startRank=0) -> Dict[bytes, int]:
     with open(tiktoken_bpe_file, "rb") as f:
         contents = f.read()
-    return {
-        base64.b64decode(token): int(rank) for token, rank in (line.split() for line in contents.splitlines() if line)
-    }
+    ll = contents.splitlines()
+    return {base64.b64decode(token): int(rank + startRank) for rank, token in enumerate(ll)}


+def _load_tiktoken_char(tiktoken_bpe_file: str, startRank=0) -> Dict[bytes, int]:
+    with open(tiktoken_bpe_file, "rb") as f:
+        contents = f.read()
+    ll = contents.splitlines()
+    return {token: int(rank + startRank) for rank, token in enumerate(ll)}


 class QWenTokenizer(PreTrainedTokenizer):
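
The replaced _load_tiktoken_bpe parsed base64-token/rank pairs from each line; the two new loaders instead take the rank from the line number, and startRank lets one vocabulary be appended after another. A self-contained sketch of how the tokenizer chains them (the next hunks do the same with the real wit_b64/wit_char files); the tiny vocabulary files and their contents here are invented for illustration:

import base64, os, tempfile

# Slightly condensed copies of the two loaders from the hunk above
# (the rank of a token is simply its line number, offset by startRank).
def _load_tiktoken_b64(path, startRank=0):
    with open(path, "rb") as f:
        ll = f.read().splitlines()
    return {base64.b64decode(token): rank + startRank for rank, token in enumerate(ll)}

def _load_tiktoken_char(path, startRank=0):
    with open(path, "rb") as f:
        ll = f.read().splitlines()
    return {token: rank + startRank for rank, token in enumerate(ll)}

# Two tiny stand-in vocabulary files; the real ones are wit_b64.tiktoken / wit_char.tiktoken.
with tempfile.TemporaryDirectory() as d:
    b64_path = os.path.join(d, "b64.tiktoken")
    char_path = os.path.join(d, "char.tiktoken")
    with open(b64_path, "wb") as f:
        f.write(b"IQ==\nIg==\n")                     # base64 for b'!' and b'"'
    with open(char_path, "wb") as f:
        f.write("你\n好\n".encode("utf-8"))          # two made-up character tokens

    ranks = _load_tiktoken_b64(b64_path)                       # byte tokens: ranks 0, 1
    ranks.update(_load_tiktoken_char(char_path, len(ranks)))   # char tokens continue: 2, 3
    print(ranks)  # {b'!': 0, b'"': 1, b'\xe4\xbd\xa0': 2, b'\xe5\xa5\xbd': 3}
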
@@ -60,9 +50,9 @@ class QWenTokenizer(PreTrainedTokenizer):

     def __init__(
         self,
-        vocab_file,
+        vocab_file_b64,
+        vocab_file_char,
         errors="replace",
-        extra_vocab_file=None,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -71,22 +61,24 @@ class QWenTokenizer(PreTrainedTokenizer):
         # use ignore if you are in streaming inference
         self.errors = errors

-        self.mergeable_ranks = _load_tiktoken_bpe(vocab_file)  # type: Dict[bytes, int]
-        self.special_tokens = {token: index for index, token in SPECIAL_TOKENS}
+        self.mergeable_ranks = _load_tiktoken_b64(vocab_file_b64)
+        self.mergeable_ranks.update(_load_tiktoken_char(vocab_file_char, len(self.mergeable_ranks)))

-        # try load extra vocab from file
-        if extra_vocab_file is not None:
-            used_ids = set(self.mergeable_ranks.values()) | set(self.special_tokens.values())
-            extra_mergeable_ranks = _load_tiktoken_bpe(extra_vocab_file)
-            for token, index in extra_mergeable_ranks.items():
-                if token in self.mergeable_ranks:
-                    logger.info(f"extra token {token} exists, skipping")
-                    continue
-                if index in used_ids:
-                    logger.info(f"the index {index} for extra token {token} exists, skipping")
-                    continue
-                self.mergeable_ranks[token] = index
-            # the index may be sparse after this, but don't worry tiktoken.Encoding will handle this
+        special = (
+            "user",
+            "assistant",
+            ENDOFTEXT,
+            IMSTART,
+            IMEND,
+        )
+        extras = tuple((f"<|extra_{i}|>" for i in range(4096 - len(self.mergeable_ranks) - len(special))))
+        special_tokens = tuple(
+            enumerate(
+                (special + extras),
+                start=len(self.mergeable_ranks),
+            )
+        )
+        self.special_tokens = {token: index for index, token in special_tokens}

         enc = tiktoken.Encoding(
             "Qwen",
@@ -95,7 +87,7 @@ class QWenTokenizer(PreTrainedTokenizer):
             special_tokens=self.special_tokens,
         )
         assert (
-            len(self.mergeable_ranks) + len(self.special_tokens) == enc.n_vocab
+            len(self.mergeable_ranks) + len(self.special_tokens) <= enc.n_vocab
         ), f"{len(self.mergeable_ranks) + len(self.special_tokens)} != {enc.n_vocab} in encoding"

         self.decoder = {v: k for k, v in self.mergeable_ranks.items()}  # type: dict[int, bytes|str]
@@ -153,8 +145,6 @@ class QWenTokenizer(PreTrainedTokenizer):
             raise ValueError("Adding regular tokens is not supported")
         for token in new_tokens:
             surface_form = token.content if isinstance(token, AddedToken) else token
-            if surface_form not in SPECIAL_TOKENS_SET:
-                raise ValueError("Adding unknown special tokens is not supported")
         return 0

     def save_vocabulary(self, save_directory: str, **kwargs) -> Tuple[str]:
@@ -0,0 +1,256 @@
+IQ==
+Ig==
+Iw==
+JA==
+JQ==
+Jg==
+Jw==
+KA==
+KQ==
+Kg==
+Kw==
+LA==
+LQ==
+Lg==
+Lw==
+MA==
+MQ==
+Mg==
+Mw==
+NA==
+NQ==
+Ng==
+Nw==
+OA==
+OQ==
+Og==
+Ow==
+PA==
+PQ==
+Pg==
+Pw==
+QA==
+QQ==
+Qg==
+Qw==
+RA==
+RQ==
+Rg==
+Rw==
+SA==
+SQ==
+Sg==
+Sw==
+TA==
+TQ==
+Tg==
+Tw==
+UA==
+UQ==
+Ug==
+Uw==
+VA==
+VQ==
+Vg==
+Vw==
+WA==
+WQ==
+Wg==
+Ww==
+XA==
+XQ==
+Xg==
+Xw==
+YA==
+YQ==
+Yg==
+Yw==
+ZA==
+ZQ==
+Zg==
+Zw==
+aA==
+aQ==
+ag==
+aw==
+bA==
+bQ==
+bg==
+bw==
+cA==
+cQ==
+cg==
+cw==
+dA==
+dQ==
+dg==
+dw==
+eA==
+eQ==
+eg==
+ew==
+fA==
+fQ==
+fg==
+oQ==
+og==
+ow==
+pA==
+pQ==
+pg==
+pw==
+qA==
+qQ==
+qg==
+qw==
+rA==
+rg==
+rw==
+sA==
+sQ==
+sg==
+sw==
+tA==
+tQ==
+tg==
+tw==
+uA==
+uQ==
+ug==
+uw==
+vA==
+vQ==
+vg==
+vw==
+wA==
+wQ==
+wg==
+ww==
+xA==
+xQ==
+xg==
+xw==
+yA==
+yQ==
+yg==
+yw==
+zA==
+zQ==
+zg==
+zw==
+0A==
+0Q==
+0g==
+0w==
+1A==
+1Q==
+1g==
+1w==
+2A==
+2Q==
+2g==
+2w==
+3A==
+3Q==
+3g==
+3w==
+4A==
+4Q==
+4g==
+4w==
+5A==
+5Q==
+5g==
+5w==
+6A==
+6Q==
+6g==
+6w==
+7A==
+7Q==
+7g==
+7w==
+8A==
+8Q==
+8g==
+8w==
+9A==
+9Q==
+9g==
+9w==
++A==
++Q==
++g==
++w==
+/A==
+/Q==
+/g==
+/w==
+AA==
+AQ==
+Ag==
+Aw==
+BA==
+BQ==
+Bg==
+Bw==
+CA==
+CQ==
+Cg==
+Cw==
+DA==
+DQ==
+Dg==
+Dw==
+EA==
+EQ==
+Eg==
+Ew==
+FA==
+FQ==
+Fg==
+Fw==
+GA==
+GQ==
+Gg==
+Gw==
+HA==
+HQ==
+Hg==
+Hw==
+IA==
+fw==
+gA==
+gQ==
+gg==
+gw==
+hA==
+hQ==
+hg==
+hw==
+iA==
+iQ==
+ig==
+iw==
+jA==
+jQ==
+jg==
+jw==
+kA==
+kQ==
+kg==
+kw==
+lA==
+lQ==
+lg==
+lw==
+mA==
+mQ==
+mg==
+mw==
+nA==
+nQ==
+ng==
+nw==
+oA==
+rQ==
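
The new 256-line file above (its path is not shown in this view; given the QWenTokenizer("./wit_b64.tiktoken", "./wit_char.tiktoken") call in wit/demo.py it is presumably wit/wit_b64.tiktoken) appears to enumerate every single byte value exactly once, base64-encoded, with the printable ASCII range listed first, an ordering reminiscent of GPT-2's byte table. A quick check, assuming the file contains exactly the 256 lines listed above:

import base64

# Decode each line into its raw byte and check that all 256 byte values are covered.
with open("wit_b64.tiktoken", "rb") as f:   # path assumed, see note above
    entries = [base64.b64decode(line) for line in f.read().splitlines()]

print(entries[:4])                                        # [b'!', b'"', b'#', b'$']
print(len(entries))                                       # 256
print(sorted(e[0] for e in entries) == list(range(256)))  # True if every byte appears once

Combined with the rank-by-line-number loader above, these entries take ids 0..255, so any UTF-8 input can fall back to per-byte tokens when a character is missing from the char vocabulary.
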
File diff suppressed because it is too large