Compare commits

..

2 Commits

Author SHA1 Message Date
Colin 6366b52fef Add reaserch sile resault. 2024-02-04 23:48:51 +08:00
Colin 9d5d590b09 Add dataset and wit. 2024-02-04 23:48:24 +08:00
11 changed files with 152825 additions and 0 deletions

5
dataset/MNBVC.py Normal file
View File

@ -0,0 +1,5 @@
from datasets import load_dataset
dataset = load_dataset("liwu/MNBVC", "wikipedia", split="train", streaming=True)
print(next(iter(dataset))) # get the first line

123
qwen/research_silu.py Normal file
View File

@ -0,0 +1,123 @@
import torch
import sys
import math
from modelscope import snapshot_download
from transformers import AutoTokenizer
from transformers import AutoConfig
from modeling_qwen import QWenLMHeadModel
from modeling_qwen import QwenRunner
import numpy as np
import torch.nn.functional as F
from qwen_generation_utils import (
make_context,
decode_tokens,
)
sys.path.append("..")
from tools import show
seed = 4321
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
model_dir = snapshot_download("qwen/Qwen-1_8B-Chat")
# model_dir = "/home/colin/.cache/modelscope/hub/qwen/Qwen-1_8B-Chat"
config, kwargs = AutoConfig.from_pretrained(
"./",
return_unused_kwargs=True,
trust_remote_code=True,
code_revision=None,
_commit_hash=None,
)
model = QWenLMHeadModel(config)
print(model)
tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
model = model.from_pretrained(model_dir)
if torch.cuda.device_count() > 0:
model = model.cuda()
model = model.eval()
index = 0
class ResearchRunner(QwenRunner):
def __init__(self, model):
super().__init__(model)
def prepareInput(self, tokenizer, query, query_assistant, history, system):
start_to = [151644]
n_to = [198]
end_to = [151645]
system_str = "system\nYou are a helpful assistant."
user_str = "user\n" + query
aassistant_str = "assistant\n" + query_assistant
system_token = start_to + tokenizer.encode(system_str, allowed_special=set()) + end_to + n_to
user_token = start_to + tokenizer.encode(user_str, allowed_special=set()) + end_to + n_to
aassistant_token = start_to + tokenizer.encode(aassistant_str, allowed_special=set())
tokens = system_token + user_token + aassistant_token
tokens = user_token + aassistant_token
tokens = start_to + tokenizer.encode("user\nHi你好\nassistant\n我是", allowed_special=set())
return "", tokens
def forwardQWenBlock(
self,
block,
hidden_states,
rotary_pos_emb_list=None,
):
layernorm_output = block.ln_1(hidden_states)
attn_outputs = self.forwardAttention(block.attn, layernorm_output, rotary_pos_emb_list)
attn_output = attn_outputs[0]
layernorm_input = attn_output + hidden_states
layernorm_output = block.ln_2(layernorm_input)
a1 = block.mlp.w1(layernorm_output)
a2 = block.mlp.w2(layernorm_output)
activation = (F.relu(a2) > 0).to(float)
act_mean = torch.mean(activation, 2)
print("Layer:" + str(block.index))
print(act_mean.cpu())
global index
if index == 0:
activation = activation.reshape(activation.shape[1], 64, -1)
show.DumpTensorToImage(activation, "./temp/activation_layer_" + str(block.index) + ".png")
intermediate_parallel = a1 * F.silu(a2)
mlp_output = block.mlp.c_proj(intermediate_parallel)
hidden_states = layernorm_input + mlp_output
return hidden_states
def isFinish(self, next_tokens):
global index
index = index + 1
finish, next = super().isFinish(next_tokens)
return finish, next
para = list(model.parameters())
runner = ResearchRunner(model)
output_ids, history, decoded = runner.Chat(tokenizer, "你好!!", "")
print(decoded)
tokens = []
for i, token in enumerate(output_ids):
de = tokenizer.decode([token])
de = str(i + 1).zfill(3) + " : " + repr(de)
tokens.append(de)
show.DumpListToFile(tokens, "./temp/token_decode_list.txt")

2
wit/__init__.py Normal file
View File

@ -0,0 +1,2 @@
from qwen.modeling_qwen import QWenLMHeadModel
from qwen.configuration_qwen import QWenConfig

45
wit/configuration_qwen.py Normal file
View File

@ -0,0 +1,45 @@
# Copyright (c) Alibaba Cloud.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
class QWenConfig:
def __init__(self):
self.vocab_size = 151936
self.hidden_size = 2048
self.num_hidden_layers = 24
self.num_attention_heads = 16
self.emb_dropout_prob = 0.0
self.attn_dropout_prob = 0.0
self.layer_norm_epsilon = 1e-6
self.initializer_range = 0.02
self.max_position_embeddings = 8192
self.scale_attn_weights = True
self.use_cache = True
self.bf16 = False
self.fp16 = False
self.fp32 = False
self.kv_channels = 128
self.rotary_pct = 1.0
self.rotary_emb_base = 10000
self.use_dynamic_ntk = True
self.use_logn_attn = True
self.use_flash_attn = "auto"
self.intermediate_size = 11008
self.no_bias = True
self.tie_word_embeddings = False
self.use_cache_quantization = False
self.use_cache_kernel = False
self.softmax_in_fp32 = False
self.chat_format = "chatml"
self.eos_token_id = 151643
self.pad_token_id = 151643
self.max_window_size = 6144
self.max_new_tokens = 512
self.do_sample = True
self.top_k = 0
self.top_p = 0.8
self.repetition_penalty = 1.1
self.model_max_length = 8192

31
wit/demo.py Normal file
View File

@ -0,0 +1,31 @@
import torch
from modelscope import snapshot_download
from modeling_wit import QWenLMHeadModel
from modeling_wit import QwenRunner
from configuration_qwen import QWenConfig
from tokenization_qwen import QWenTokenizer
seed = 4321
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
model_dir = snapshot_download("qwen/Qwen-1_8B-Chat")
# model_dir = "/home/colin/.cache/modelscope/hub/qwen/Qwen-1_8B-Chat"
config = QWenConfig()
model = QWenLMHeadModel(config)
print(model)
tokenizer = QWenTokenizer("./qwen.tiktoken")
model = model.from_pretrained(model_dir).cuda()
model = model.eval()
# model = model.train() # control by @torch.no_grad()
runner = QwenRunner(model)
response, history, decode_tokens = runner.Chat(tokenizer, "东南亚国家日本的首都是什么市", "")
print(decode_tokens)

View File

@ -0,0 +1,202 @@
{
"metadata": {
"total_size": 3673657344
},
"weight_map": {
"lm_head.weight": "model-00002-of-00002.safetensors",
"transformer.h.0.attn.c_attn.bias": "model-00001-of-00002.safetensors",
"transformer.h.0.attn.c_attn.weight": "model-00001-of-00002.safetensors",
"transformer.h.0.attn.c_proj.weight": "model-00001-of-00002.safetensors",
"transformer.h.0.ln_1.weight": "model-00001-of-00002.safetensors",
"transformer.h.0.ln_2.weight": "model-00001-of-00002.safetensors",
"transformer.h.0.mlp.c_proj.weight": "model-00001-of-00002.safetensors",
"transformer.h.0.mlp.w1.weight": "model-00001-of-00002.safetensors",
"transformer.h.0.mlp.w2.weight": "model-00001-of-00002.safetensors",
"transformer.h.1.attn.c_attn.bias": "model-00001-of-00002.safetensors",
"transformer.h.1.attn.c_attn.weight": "model-00001-of-00002.safetensors",
"transformer.h.1.attn.c_proj.weight": "model-00001-of-00002.safetensors",
"transformer.h.1.ln_1.weight": "model-00001-of-00002.safetensors",
"transformer.h.1.ln_2.weight": "model-00001-of-00002.safetensors",
"transformer.h.1.mlp.c_proj.weight": "model-00001-of-00002.safetensors",
"transformer.h.1.mlp.w1.weight": "model-00001-of-00002.safetensors",
"transformer.h.1.mlp.w2.weight": "model-00001-of-00002.safetensors",
"transformer.h.10.attn.c_attn.bias": "model-00001-of-00002.safetensors",
"transformer.h.10.attn.c_attn.weight": "model-00001-of-00002.safetensors",
"transformer.h.10.attn.c_proj.weight": "model-00001-of-00002.safetensors",
"transformer.h.10.ln_1.weight": "model-00001-of-00002.safetensors",
"transformer.h.10.ln_2.weight": "model-00001-of-00002.safetensors",
"transformer.h.10.mlp.c_proj.weight": "model-00001-of-00002.safetensors",
"transformer.h.10.mlp.w1.weight": "model-00001-of-00002.safetensors",
"transformer.h.10.mlp.w2.weight": "model-00001-of-00002.safetensors",
"transformer.h.11.attn.c_attn.bias": "model-00001-of-00002.safetensors",
"transformer.h.11.attn.c_attn.weight": "model-00001-of-00002.safetensors",
"transformer.h.11.attn.c_proj.weight": "model-00001-of-00002.safetensors",
"transformer.h.11.ln_1.weight": "model-00001-of-00002.safetensors",
"transformer.h.11.ln_2.weight": "model-00001-of-00002.safetensors",
"transformer.h.11.mlp.c_proj.weight": "model-00001-of-00002.safetensors",
"transformer.h.11.mlp.w1.weight": "model-00001-of-00002.safetensors",
"transformer.h.11.mlp.w2.weight": "model-00001-of-00002.safetensors",
"transformer.h.12.attn.c_attn.bias": "model-00001-of-00002.safetensors",
"transformer.h.12.attn.c_attn.weight": "model-00001-of-00002.safetensors",
"transformer.h.12.attn.c_proj.weight": "model-00001-of-00002.safetensors",
"transformer.h.12.ln_1.weight": "model-00001-of-00002.safetensors",
"transformer.h.12.ln_2.weight": "model-00001-of-00002.safetensors",
"transformer.h.12.mlp.c_proj.weight": "model-00001-of-00002.safetensors",
"transformer.h.12.mlp.w1.weight": "model-00001-of-00002.safetensors",
"transformer.h.12.mlp.w2.weight": "model-00001-of-00002.safetensors",
"transformer.h.13.attn.c_attn.bias": "model-00001-of-00002.safetensors",
"transformer.h.13.attn.c_attn.weight": "model-00001-of-00002.safetensors",
"transformer.h.13.attn.c_proj.weight": "model-00001-of-00002.safetensors",
"transformer.h.13.ln_1.weight": "model-00001-of-00002.safetensors",
"transformer.h.13.ln_2.weight": "model-00001-of-00002.safetensors",
"transformer.h.13.mlp.c_proj.weight": "model-00001-of-00002.safetensors",
"transformer.h.13.mlp.w1.weight": "model-00001-of-00002.safetensors",
"transformer.h.13.mlp.w2.weight": "model-00001-of-00002.safetensors",
"transformer.h.14.attn.c_attn.bias": "model-00002-of-00002.safetensors",
"transformer.h.14.attn.c_attn.weight": "model-00002-of-00002.safetensors",
"transformer.h.14.attn.c_proj.weight": "model-00002-of-00002.safetensors",
"transformer.h.14.ln_1.weight": "model-00001-of-00002.safetensors",
"transformer.h.14.ln_2.weight": "model-00002-of-00002.safetensors",
"transformer.h.14.mlp.c_proj.weight": "model-00002-of-00002.safetensors",
"transformer.h.14.mlp.w1.weight": "model-00002-of-00002.safetensors",
"transformer.h.14.mlp.w2.weight": "model-00002-of-00002.safetensors",
"transformer.h.15.attn.c_attn.bias": "model-00002-of-00002.safetensors",
"transformer.h.15.attn.c_attn.weight": "model-00002-of-00002.safetensors",
"transformer.h.15.attn.c_proj.weight": "model-00002-of-00002.safetensors",
"transformer.h.15.ln_1.weight": "model-00002-of-00002.safetensors",
"transformer.h.15.ln_2.weight": "model-00002-of-00002.safetensors",
"transformer.h.15.mlp.c_proj.weight": "model-00002-of-00002.safetensors",
"transformer.h.15.mlp.w1.weight": "model-00002-of-00002.safetensors",
"transformer.h.15.mlp.w2.weight": "model-00002-of-00002.safetensors",
"transformer.h.16.attn.c_attn.bias": "model-00002-of-00002.safetensors",
"transformer.h.16.attn.c_attn.weight": "model-00002-of-00002.safetensors",
"transformer.h.16.attn.c_proj.weight": "model-00002-of-00002.safetensors",
"transformer.h.16.ln_1.weight": "model-00002-of-00002.safetensors",
"transformer.h.16.ln_2.weight": "model-00002-of-00002.safetensors",
"transformer.h.16.mlp.c_proj.weight": "model-00002-of-00002.safetensors",
"transformer.h.16.mlp.w1.weight": "model-00002-of-00002.safetensors",
"transformer.h.16.mlp.w2.weight": "model-00002-of-00002.safetensors",
"transformer.h.17.attn.c_attn.bias": "model-00002-of-00002.safetensors",
"transformer.h.17.attn.c_attn.weight": "model-00002-of-00002.safetensors",
"transformer.h.17.attn.c_proj.weight": "model-00002-of-00002.safetensors",
"transformer.h.17.ln_1.weight": "model-00002-of-00002.safetensors",
"transformer.h.17.ln_2.weight": "model-00002-of-00002.safetensors",
"transformer.h.17.mlp.c_proj.weight": "model-00002-of-00002.safetensors",
"transformer.h.17.mlp.w1.weight": "model-00002-of-00002.safetensors",
"transformer.h.17.mlp.w2.weight": "model-00002-of-00002.safetensors",
"transformer.h.18.attn.c_attn.bias": "model-00002-of-00002.safetensors",
"transformer.h.18.attn.c_attn.weight": "model-00002-of-00002.safetensors",
"transformer.h.18.attn.c_proj.weight": "model-00002-of-00002.safetensors",
"transformer.h.18.ln_1.weight": "model-00002-of-00002.safetensors",
"transformer.h.18.ln_2.weight": "model-00002-of-00002.safetensors",
"transformer.h.18.mlp.c_proj.weight": "model-00002-of-00002.safetensors",
"transformer.h.18.mlp.w1.weight": "model-00002-of-00002.safetensors",
"transformer.h.18.mlp.w2.weight": "model-00002-of-00002.safetensors",
"transformer.h.19.attn.c_attn.bias": "model-00002-of-00002.safetensors",
"transformer.h.19.attn.c_attn.weight": "model-00002-of-00002.safetensors",
"transformer.h.19.attn.c_proj.weight": "model-00002-of-00002.safetensors",
"transformer.h.19.ln_1.weight": "model-00002-of-00002.safetensors",
"transformer.h.19.ln_2.weight": "model-00002-of-00002.safetensors",
"transformer.h.19.mlp.c_proj.weight": "model-00002-of-00002.safetensors",
"transformer.h.19.mlp.w1.weight": "model-00002-of-00002.safetensors",
"transformer.h.19.mlp.w2.weight": "model-00002-of-00002.safetensors",
"transformer.h.2.attn.c_attn.bias": "model-00001-of-00002.safetensors",
"transformer.h.2.attn.c_attn.weight": "model-00001-of-00002.safetensors",
"transformer.h.2.attn.c_proj.weight": "model-00001-of-00002.safetensors",
"transformer.h.2.ln_1.weight": "model-00001-of-00002.safetensors",
"transformer.h.2.ln_2.weight": "model-00001-of-00002.safetensors",
"transformer.h.2.mlp.c_proj.weight": "model-00001-of-00002.safetensors",
"transformer.h.2.mlp.w1.weight": "model-00001-of-00002.safetensors",
"transformer.h.2.mlp.w2.weight": "model-00001-of-00002.safetensors",
"transformer.h.20.attn.c_attn.bias": "model-00002-of-00002.safetensors",
"transformer.h.20.attn.c_attn.weight": "model-00002-of-00002.safetensors",
"transformer.h.20.attn.c_proj.weight": "model-00002-of-00002.safetensors",
"transformer.h.20.ln_1.weight": "model-00002-of-00002.safetensors",
"transformer.h.20.ln_2.weight": "model-00002-of-00002.safetensors",
"transformer.h.20.mlp.c_proj.weight": "model-00002-of-00002.safetensors",
"transformer.h.20.mlp.w1.weight": "model-00002-of-00002.safetensors",
"transformer.h.20.mlp.w2.weight": "model-00002-of-00002.safetensors",
"transformer.h.21.attn.c_attn.bias": "model-00002-of-00002.safetensors",
"transformer.h.21.attn.c_attn.weight": "model-00002-of-00002.safetensors",
"transformer.h.21.attn.c_proj.weight": "model-00002-of-00002.safetensors",
"transformer.h.21.ln_1.weight": "model-00002-of-00002.safetensors",
"transformer.h.21.ln_2.weight": "model-00002-of-00002.safetensors",
"transformer.h.21.mlp.c_proj.weight": "model-00002-of-00002.safetensors",
"transformer.h.21.mlp.w1.weight": "model-00002-of-00002.safetensors",
"transformer.h.21.mlp.w2.weight": "model-00002-of-00002.safetensors",
"transformer.h.22.attn.c_attn.bias": "model-00002-of-00002.safetensors",
"transformer.h.22.attn.c_attn.weight": "model-00002-of-00002.safetensors",
"transformer.h.22.attn.c_proj.weight": "model-00002-of-00002.safetensors",
"transformer.h.22.ln_1.weight": "model-00002-of-00002.safetensors",
"transformer.h.22.ln_2.weight": "model-00002-of-00002.safetensors",
"transformer.h.22.mlp.c_proj.weight": "model-00002-of-00002.safetensors",
"transformer.h.22.mlp.w1.weight": "model-00002-of-00002.safetensors",
"transformer.h.22.mlp.w2.weight": "model-00002-of-00002.safetensors",
"transformer.h.23.attn.c_attn.bias": "model-00002-of-00002.safetensors",
"transformer.h.23.attn.c_attn.weight": "model-00002-of-00002.safetensors",
"transformer.h.23.attn.c_proj.weight": "model-00002-of-00002.safetensors",
"transformer.h.23.ln_1.weight": "model-00002-of-00002.safetensors",
"transformer.h.23.ln_2.weight": "model-00002-of-00002.safetensors",
"transformer.h.23.mlp.c_proj.weight": "model-00002-of-00002.safetensors",
"transformer.h.23.mlp.w1.weight": "model-00002-of-00002.safetensors",
"transformer.h.23.mlp.w2.weight": "model-00002-of-00002.safetensors",
"transformer.h.3.attn.c_attn.bias": "model-00001-of-00002.safetensors",
"transformer.h.3.attn.c_attn.weight": "model-00001-of-00002.safetensors",
"transformer.h.3.attn.c_proj.weight": "model-00001-of-00002.safetensors",
"transformer.h.3.ln_1.weight": "model-00001-of-00002.safetensors",
"transformer.h.3.ln_2.weight": "model-00001-of-00002.safetensors",
"transformer.h.3.mlp.c_proj.weight": "model-00001-of-00002.safetensors",
"transformer.h.3.mlp.w1.weight": "model-00001-of-00002.safetensors",
"transformer.h.3.mlp.w2.weight": "model-00001-of-00002.safetensors",
"transformer.h.4.attn.c_attn.bias": "model-00001-of-00002.safetensors",
"transformer.h.4.attn.c_attn.weight": "model-00001-of-00002.safetensors",
"transformer.h.4.attn.c_proj.weight": "model-00001-of-00002.safetensors",
"transformer.h.4.ln_1.weight": "model-00001-of-00002.safetensors",
"transformer.h.4.ln_2.weight": "model-00001-of-00002.safetensors",
"transformer.h.4.mlp.c_proj.weight": "model-00001-of-00002.safetensors",
"transformer.h.4.mlp.w1.weight": "model-00001-of-00002.safetensors",
"transformer.h.4.mlp.w2.weight": "model-00001-of-00002.safetensors",
"transformer.h.5.attn.c_attn.bias": "model-00001-of-00002.safetensors",
"transformer.h.5.attn.c_attn.weight": "model-00001-of-00002.safetensors",
"transformer.h.5.attn.c_proj.weight": "model-00001-of-00002.safetensors",
"transformer.h.5.ln_1.weight": "model-00001-of-00002.safetensors",
"transformer.h.5.ln_2.weight": "model-00001-of-00002.safetensors",
"transformer.h.5.mlp.c_proj.weight": "model-00001-of-00002.safetensors",
"transformer.h.5.mlp.w1.weight": "model-00001-of-00002.safetensors",
"transformer.h.5.mlp.w2.weight": "model-00001-of-00002.safetensors",
"transformer.h.6.attn.c_attn.bias": "model-00001-of-00002.safetensors",
"transformer.h.6.attn.c_attn.weight": "model-00001-of-00002.safetensors",
"transformer.h.6.attn.c_proj.weight": "model-00001-of-00002.safetensors",
"transformer.h.6.ln_1.weight": "model-00001-of-00002.safetensors",
"transformer.h.6.ln_2.weight": "model-00001-of-00002.safetensors",
"transformer.h.6.mlp.c_proj.weight": "model-00001-of-00002.safetensors",
"transformer.h.6.mlp.w1.weight": "model-00001-of-00002.safetensors",
"transformer.h.6.mlp.w2.weight": "model-00001-of-00002.safetensors",
"transformer.h.7.attn.c_attn.bias": "model-00001-of-00002.safetensors",
"transformer.h.7.attn.c_attn.weight": "model-00001-of-00002.safetensors",
"transformer.h.7.attn.c_proj.weight": "model-00001-of-00002.safetensors",
"transformer.h.7.ln_1.weight": "model-00001-of-00002.safetensors",
"transformer.h.7.ln_2.weight": "model-00001-of-00002.safetensors",
"transformer.h.7.mlp.c_proj.weight": "model-00001-of-00002.safetensors",
"transformer.h.7.mlp.w1.weight": "model-00001-of-00002.safetensors",
"transformer.h.7.mlp.w2.weight": "model-00001-of-00002.safetensors",
"transformer.h.8.attn.c_attn.bias": "model-00001-of-00002.safetensors",
"transformer.h.8.attn.c_attn.weight": "model-00001-of-00002.safetensors",
"transformer.h.8.attn.c_proj.weight": "model-00001-of-00002.safetensors",
"transformer.h.8.ln_1.weight": "model-00001-of-00002.safetensors",
"transformer.h.8.ln_2.weight": "model-00001-of-00002.safetensors",
"transformer.h.8.mlp.c_proj.weight": "model-00001-of-00002.safetensors",
"transformer.h.8.mlp.w1.weight": "model-00001-of-00002.safetensors",
"transformer.h.8.mlp.w2.weight": "model-00001-of-00002.safetensors",
"transformer.h.9.attn.c_attn.bias": "model-00001-of-00002.safetensors",
"transformer.h.9.attn.c_attn.weight": "model-00001-of-00002.safetensors",
"transformer.h.9.attn.c_proj.weight": "model-00001-of-00002.safetensors",
"transformer.h.9.ln_1.weight": "model-00001-of-00002.safetensors",
"transformer.h.9.ln_2.weight": "model-00001-of-00002.safetensors",
"transformer.h.9.mlp.c_proj.weight": "model-00001-of-00002.safetensors",
"transformer.h.9.mlp.w1.weight": "model-00001-of-00002.safetensors",
"transformer.h.9.mlp.w2.weight": "model-00001-of-00002.safetensors",
"transformer.ln_f.weight": "model-00002-of-00002.safetensors",
"transformer.wte.weight": "model-00001-of-00002.safetensors"
}
}

389
wit/modeling_wit.py Normal file
View File

@ -0,0 +1,389 @@
import copy
import math
import os
import sys
import gc
from tqdm import auto as tqdm_lib
import json
from typing import Optional, Tuple, Union, Callable, List, Any, Generator
from einops import rearrange
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.nn import CrossEntropyLoss
from torch import nn
from safetensors.torch import load_file as safe_load_file
from safetensors.torch import save_file as safe_save_file
from qwen_generation_utils import (
make_context,
decode_tokens,
)
sys.path.append("..")
from tools import show
from tools import mem_tracker
# tracker = mem_tracker.MemTracker()
# tracker.track()
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
return self._norm(x.float()).type_as(x) * self.weight
class QWenAttention(nn.Module):
def __init__(self, config, index):
super().__init__()
self.hidden_size = config.hidden_size
self.split_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
self.projection_size = config.kv_channels * config.num_attention_heads
self.c_attn = nn.Linear(config.hidden_size, 3 * self.projection_size)
self.c_proj = nn.Linear(config.hidden_size, self.projection_size, bias=not config.no_bias)
self.attn_dropout = nn.Dropout(config.attn_dropout_prob)
self.index = index
def _split_heads(self, tensor, num_heads, attn_head_size):
new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
tensor = tensor.view(new_shape)
return tensor
def _merge_heads(self, tensor, num_heads, attn_head_size):
tensor = tensor.contiguous()
new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
return tensor.view(new_shape)
class QWenMLP(nn.Module):
def __init__(self, config):
super().__init__()
ff_dim_in = config.intermediate_size // 2
self.w1 = nn.Linear(config.hidden_size, ff_dim_in, bias=not config.no_bias)
self.w2 = nn.Linear(config.hidden_size, ff_dim_in, bias=not config.no_bias)
self.c_proj = nn.Linear(ff_dim_in, config.hidden_size, bias=not config.no_bias)
class QWenBlock(nn.Module):
def __init__(self, config, index):
super().__init__()
self.ln_1 = RMSNorm(
config.hidden_size,
eps=config.layer_norm_epsilon,
)
self.attn = QWenAttention(config, index)
self.ln_2 = RMSNorm(
config.hidden_size,
eps=config.layer_norm_epsilon,
)
self.mlp = QWenMLP(config)
self.index = index
class QWenModel(nn.Module):
def __init__(self, config):
super().__init__()
self.wte = nn.Embedding(config.vocab_size, config.hidden_size)
self.drop = nn.Dropout(config.emb_dropout_prob)
dim = config.kv_channels
self.h = nn.ModuleList([QWenBlock(config, i) for i in range(config.num_hidden_layers)])
self.ln_f = RMSNorm(
config.hidden_size,
eps=config.layer_norm_epsilon,
)
self.dim = dim
self.base = config.rotary_emb_base
inv_freq = 1.0 / (self.base ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
self._rotary_pos_emb_cache = None
self._seq_len_cached = 0
self._ntk_alpha_cached = 1.0
def update_rotary_pos_emb_cache(self, seqlen, ntk_alpha=1.0):
if seqlen > self._seq_len_cached or ntk_alpha != self._ntk_alpha_cached:
base = self.base * ntk_alpha ** (self.dim / (self.dim - 2))
self.inv_freq = 1.0 / (
base ** (torch.arange(0, self.dim, 2, device=self.inv_freq.device).float() / self.dim)
)
self._seq_len_cached = max(2 * seqlen, 16)
self._ntk_alpha_cached = ntk_alpha
seq = torch.arange(self._seq_len_cached, device=self.inv_freq.device)
freqs = torch.outer(seq.type_as(self.inv_freq), self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
emb = rearrange(emb, "n d -> 1 n 1 d")
cos, sin = emb.cos(), emb.sin()
self._rotary_pos_emb_cache = [cos, sin]
class QWenLMHeadModel(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.transformer = QWenModel(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]]):
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
resolved_archive_file = os.path.join(pretrained_model_name_or_path, "model.safetensors.index.json")
print(f"loading weights file {resolved_archive_file}")
with open(resolved_archive_file, "r") as f:
index = json.loads(f.read())
shard_filenames = sorted(set(index["weight_map"].values()))
resolved_archive_file = [os.path.join(pretrained_model_name_or_path, f) for f in shard_filenames]
model = cls._load_pretrained_model(resolved_archive_file)
return model
def _load_state_dict_into_model(self, model_to_load, state_dict, start_prefix):
metadata = getattr(state_dict, "_metadata", None)
state_dict = state_dict.copy()
if metadata is not None:
state_dict._metadata = metadata
error_msgs = []
def load(module: nn.Module, state_dict, prefix=""):
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
if len([key for key in state_dict if key.startswith(prefix)]) > 0:
module._load_from_state_dict(*args)
for name, child in module._modules.items():
if child is not None:
load(child, state_dict, prefix + name + ".")
load(model_to_load, state_dict, prefix=start_prefix)
del state_dict
return error_msgs
def _load_pretrained_model(cls, resolved_archive_file):
start_prefix = ""
model_to_load = cls
if len(resolved_archive_file) > 1:
resolved_archive_file = tqdm_lib.tqdm(resolved_archive_file, desc="Loading checkpoint shards")
for shard_file in resolved_archive_file:
state_dict = safe_load_file(shard_file)
cls._load_state_dict_into_model(model_to_load, state_dict, start_prefix)
del state_dict # force memory release
gc.collect()
print(f"All model checkpoint weights were used when initializing {cls.__class__.__name__}.\n")
return cls
class QwenRunner:
def __init__(self, qwen):
self.qwen = qwen
@torch.no_grad()
def Chat(
self,
tokenizer,
query: str,
query_assistant: str,
system: str = "You are a helpful assistant.",
history=[],
):
qwen = self.qwen
history = copy.deepcopy(history)
raw_text, context_tokens = self.prepareInput(tokenizer, query, query_assistant, history, system)
input_ids = torch.tensor([context_tokens]).to(next(qwen.parameters()).device)
self.unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
while True:
outputs = self.forwardQWen(input_ids)
next_token_scores = outputs[:, -1, :]
next_token_scores = self.repetition_penalty(input_ids, next_token_scores)
next_token_scores = self.top_p(next_token_scores)
next_tokens = self.sample(next_token_scores)
finish, next_tokens = self.isFinish(next_tokens)
if finish:
break
input_ids = torch.cat([input_ids, next_tokens], dim=-1)
decoded, response, end_reason = decode_tokens(
input_ids[0],
tokenizer,
raw_text_len=len(raw_text),
context_length=len(context_tokens),
errors="replace",
)
history.append((query, response))
return input_ids[0].cpu().tolist(), history, decoded
def _rotate_half(self, x):
x = rearrange(x, "... (j d) -> ... j d", j=2)
x1, x2 = x.unbind(dim=-2)
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(self, t, freqs):
rot_dim = freqs[0].shape[-1]
cos, sin = freqs
t_float = t.float()
t_rot, t_pass = t_float[..., :rot_dim], t_float[..., rot_dim:]
t_rot = (t_rot * cos) + (self._rotate_half(t_rot) * sin)
return torch.cat((t_rot, t_pass), dim=-1).type_as(t)
def split_heads(
self,
attention,
hidden_states: Optional[Tuple[torch.FloatTensor]],
):
atten = attention
mixed_x_layer = atten.c_attn(hidden_states)
query, key, value = mixed_x_layer.split(atten.split_size, dim=2)
query = atten._split_heads(query, atten.num_heads, atten.head_dim)
key = atten._split_heads(key, atten.num_heads, atten.head_dim)
value = atten._split_heads(value, atten.num_heads, atten.head_dim)
return query, key, value
def pos_emb(self, query, key, rotary_pos_emb_list):
rotary_pos_emb = rotary_pos_emb_list[0]
rotary_pos_emb = [i[:, -query.shape[1] :, :, :] for i in rotary_pos_emb]
rotary_pos_emb = (rotary_pos_emb,) * 2
query = self.apply_rotary_pos_emb(query, rotary_pos_emb[0])
key = self.apply_rotary_pos_emb(key, rotary_pos_emb[1])
return query, key
def attention(self, attention, query, key, value, causal_mask):
query = query.permute(0, 2, 1, 3)
key = key.permute(0, 2, 1, 3)
value = value.permute(0, 2, 1, 3)
attn_output = F.scaled_dot_product_attention(query, key, value, attn_mask=causal_mask).transpose(1, 2)
context_layer = attention._merge_heads(attn_output, attention.num_heads, attention.head_dim)
attn_output = attention.c_proj(context_layer)
return attn_output
def build_mask(self, query):
size = query.size(1)
causal_mask = torch.tril(torch.ones((size, size), dtype=torch.bool, device=query.device)).view(1, 1, size, size)
return causal_mask
def forwardAttention(
self,
attention,
hidden_states: Optional[Tuple[torch.FloatTensor]],
rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None,
):
query, key, value = self.split_heads(attention, hidden_states)
query, key = self.pos_emb(query, key, rotary_pos_emb_list)
causal_mask = self.build_mask(query)
return self.attention(attention, query, key, value, causal_mask)
def forwardQWenBlock(
self,
block,
hidden_states: Optional[Tuple[torch.FloatTensor]],
rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None,
):
layernorm_output = block.ln_1(hidden_states)
attn_outputs = self.forwardAttention(block.attn, layernorm_output, rotary_pos_emb_list)
attn_output = attn_outputs[0]
layernorm_input = attn_output + hidden_states
layernorm_output = block.ln_2(layernorm_input)
a1 = block.mlp.w1(layernorm_output)
a2 = block.mlp.w2(layernorm_output)
intermediate_parallel = a1 * F.silu(a2)
mlp_output = block.mlp.c_proj(intermediate_parallel)
hidden_states = layernorm_input + mlp_output
return hidden_states
def forwardQWen(
self,
input_ids: Optional[torch.LongTensor] = None,
labels: Optional[torch.LongTensor] = None,
):
transfm = self.qwen.transformer
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
hidden_states = transfm.wte(input_ids)
kv_seq_len = hidden_states.size()[1]
transfm.update_rotary_pos_emb_cache(kv_seq_len, ntk_alpha=1.0)
cos, sin = transfm._rotary_pos_emb_cache
rotary_pos_emb_list = [[cos[:, :kv_seq_len], sin[:, :kv_seq_len]]]
hidden_states = transfm.drop(hidden_states)
output_shape = input_shape + (hidden_states.size(-1),)
for block in transfm.h:
hidden_states = self.forwardQWenBlock(block, hidden_states, rotary_pos_emb_list=rotary_pos_emb_list)
hidden_states = transfm.ln_f(hidden_states)
hidden_states = hidden_states.view(output_shape)
lm_logits = self.qwen.lm_head(hidden_states)
loss = None
if labels is not None:
labels = labels.to(lm_logits.device)
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss_fct = CrossEntropyLoss()
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
# shift_labels = torch.ones([1,19]).to(lm_logits.device).to(torch.int64)
# shift_logits = lm_logits[..., :-1, :].contiguous()
# loss_fct = CrossEntropyLoss()
# loss = loss_fct(
# shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
# )
# loss.backward()
return lm_logits
def prepareInput(self, tokenizer, query, query_assistant, history, system):
return make_context(tokenizer, query, query_assistant, history=history, system=system)
def repetition_penalty(self, input_ids, next_token_scores):
penalty = self.qwen.config.repetition_penalty
score = torch.gather(next_token_scores, 1, input_ids)
# if score < 0 then repetition penalty has to be multiplied to reduce the token probabilities
score = torch.where(score < 0, score * penalty, score / penalty)
next_token_scores = next_token_scores.scatter_(1, input_ids, score)
return next_token_scores
def top_p(self, next_token_scores):
top_p = self.qwen.config.top_p
filter_value = -float("Inf")
min_tokens_to_keep = 1
sorted_logits, sorted_indices = torch.sort(next_token_scores, descending=False)
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
# Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
# Keep at least min_tokens_to_keep
sorted_indices_to_remove[..., -min_tokens_to_keep:] = 0
# scatter sorted tensors to original indexing
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
next_token_scores = next_token_scores.masked_fill(indices_to_remove, filter_value)
return next_token_scores
def sample(self, next_token_scores):
probs = nn.functional.softmax(next_token_scores, dim=-1)
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
return next_tokens
def isFinish(self, next_tokens):
pad_token_id = self.qwen.config.pad_token_id
eos_token_id_tensor = torch.tensor([self.qwen.config.eos_token_id]).to(next_tokens.device)
next_tokens = next_tokens * self.unfinished_sequences + pad_token_id * (1 - self.unfinished_sequences)
self.unfinished_sequences = self.unfinished_sequences.mul(
next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
)
return self.unfinished_sequences.max() == 0, next_tokens[:, None]

151643
wit/qwen.tiktoken Normal file

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,109 @@
# Copyright (c) Alibaba Cloud.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Generation support."""
from typing import Tuple, List, Union, Iterable
import numpy as np
import torch
import torch.nn.functional as F
from transformers import PreTrainedTokenizer
from transformers import logging
from transformers.generation import LogitsProcessor
logger = logging.get_logger(__name__)
# Types.
HistoryType = List[Tuple[str, str]]
TokensType = List[int]
BatchTokensType = List[List[int]]
def make_context(
tokenizer: PreTrainedTokenizer,
query: str,
query_assistant: str = "",
history: List[Tuple[str, str]] = None,
system: str = "",
max_window_size: int = 6144,
):
if history is None:
history = []
im_start, im_end = "<|im_start|>", "<|im_end|>"
im_start_tokens = [tokenizer.im_start_id]
im_end_tokens = [tokenizer.im_end_id]
nl_tokens = tokenizer.encode("\n")
def _tokenize_str(role, content):
return f"{role}\n{content}", tokenizer.encode(role, allowed_special=set()) + nl_tokens + tokenizer.encode(
content, allowed_special=set()
)
system_text, system_tokens_part = _tokenize_str("system", system)
system_tokens = im_start_tokens + system_tokens_part + im_end_tokens
assistant_tokens = tokenizer.encode(query_assistant, allowed_special=set())
raw_text = ""
context_tokens = []
for turn_query, turn_response in reversed(history):
query_text, query_tokens_part = _tokenize_str("user", turn_query)
query_tokens = im_start_tokens + query_tokens_part + im_end_tokens
response_text, response_tokens_part = _tokenize_str("assistant", turn_response)
response_tokens = im_start_tokens + response_tokens_part + im_end_tokens
next_context_tokens = nl_tokens + query_tokens + nl_tokens + response_tokens
prev_chat = f"\n{im_start}{query_text}{im_end}\n{im_start}{response_text}{im_end}"
current_context_size = len(system_tokens) + len(next_context_tokens) + len(context_tokens)
if current_context_size < max_window_size:
context_tokens = next_context_tokens + context_tokens
raw_text = prev_chat + raw_text
else:
break
context_tokens = system_tokens + context_tokens
raw_text = f"{im_start}{system_text}{im_end}" + raw_text
context_tokens += (
nl_tokens
+ im_start_tokens
+ _tokenize_str("user", query)[1]
+ im_end_tokens
+ nl_tokens
+ im_start_tokens
+ tokenizer.encode("assistant")
+ nl_tokens
+ assistant_tokens
)
raw_text += f"\n{im_start}user\n{query}{im_end}\n{im_start}assistant\n{query_assistant}"
return raw_text, context_tokens
def decode_tokens(
tokens: Union[torch.LongTensor, TokensType],
tokenizer: PreTrainedTokenizer,
raw_text_len: int = 0,
context_length: int = 0,
errors: str = "replace",
) -> str:
if torch.is_tensor(tokens):
tokens = tokens.cpu().numpy().tolist()
end_reason = f"Gen length {len(tokens)}"
eod_token_idx = context_length
for eod_token_idx in range(context_length, len(tokens)):
if tokens[eod_token_idx] in [tokenizer.im_start_id, tokenizer.im_end_id]:
end_reason = f"Gen {tokenizer.decode([tokens[eod_token_idx]])!r}"
break
decoded = tokenizer.decode(tokens, errors=errors)
decode_tokens = tokenizer.decode(tokens[:eod_token_idx], errors=errors)
trim_decode_tokens = decode_tokens[raw_text_len:]
trim_decode_tokens = trim_decode_tokens.strip()
return decoded, trim_decode_tokens, end_reason

266
wit/tokenization_qwen.py Normal file
View File

@ -0,0 +1,266 @@
# Copyright (c) Alibaba Cloud.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Tokenization classes for QWen."""
import base64
import logging
import os
import unicodedata
from typing import Collection, Dict, List, Set, Tuple, Union
import tiktoken
from transformers import PreTrainedTokenizer, AddedToken
logger = logging.getLogger(__name__)
VOCAB_FILES_NAMES = {"vocab_file": "qwen.tiktoken"}
PAT_STR = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
ENDOFTEXT = "<|endoftext|>"
IMSTART = "<|im_start|>"
IMEND = "<|im_end|>"
# as the default behavior is changed to allow special tokens in
# regular texts, the surface forms of special tokens need to be
# as different as possible to minimize the impact
EXTRAS = tuple((f"<|extra_{i}|>" for i in range(205)))
# changed to use actual index to avoid misconfiguration with vocabulary expansion
SPECIAL_START_ID = 151643
SPECIAL_TOKENS = tuple(
enumerate(
(
(
ENDOFTEXT,
IMSTART,
IMEND,
)
+ EXTRAS
),
start=SPECIAL_START_ID,
)
)
SPECIAL_TOKENS_SET = set(t for i, t in SPECIAL_TOKENS)
def _load_tiktoken_bpe(tiktoken_bpe_file: str) -> Dict[bytes, int]:
with open(tiktoken_bpe_file, "rb") as f:
contents = f.read()
return {
base64.b64decode(token): int(rank) for token, rank in (line.split() for line in contents.splitlines() if line)
}
class QWenTokenizer(PreTrainedTokenizer):
"""QWen tokenizer."""
vocab_files_names = VOCAB_FILES_NAMES
def __init__(
self,
vocab_file,
errors="replace",
extra_vocab_file=None,
**kwargs,
):
super().__init__(**kwargs)
# how to handle errors in decoding UTF-8 byte sequences
# use ignore if you are in streaming inference
self.errors = errors
self.mergeable_ranks = _load_tiktoken_bpe(vocab_file) # type: Dict[bytes, int]
self.special_tokens = {token: index for index, token in SPECIAL_TOKENS}
# try load extra vocab from file
if extra_vocab_file is not None:
used_ids = set(self.mergeable_ranks.values()) | set(self.special_tokens.values())
extra_mergeable_ranks = _load_tiktoken_bpe(extra_vocab_file)
for token, index in extra_mergeable_ranks.items():
if token in self.mergeable_ranks:
logger.info(f"extra token {token} exists, skipping")
continue
if index in used_ids:
logger.info(f"the index {index} for extra token {token} exists, skipping")
continue
self.mergeable_ranks[token] = index
# the index may be sparse after this, but don't worry tiktoken.Encoding will handle this
enc = tiktoken.Encoding(
"Qwen",
pat_str=PAT_STR,
mergeable_ranks=self.mergeable_ranks,
special_tokens=self.special_tokens,
)
assert (
len(self.mergeable_ranks) + len(self.special_tokens) == enc.n_vocab
), f"{len(self.mergeable_ranks) + len(self.special_tokens)} != {enc.n_vocab} in encoding"
self.decoder = {v: k for k, v in self.mergeable_ranks.items()} # type: dict[int, bytes|str]
self.decoder.update({v: k for k, v in self.special_tokens.items()})
self.tokenizer = enc # type: tiktoken.Encoding
self.eod_id = self.tokenizer.eot_token
self.im_start_id = self.special_tokens[IMSTART]
self.im_end_id = self.special_tokens[IMEND]
def __getstate__(self):
# for pickle lovers
state = self.__dict__.copy()
del state["tokenizer"]
return state
def __setstate__(self, state):
# tokenizer is not python native; don't pass it; rebuild it
self.__dict__.update(state)
enc = tiktoken.Encoding(
"Qwen",
pat_str=PAT_STR,
mergeable_ranks=self.mergeable_ranks,
special_tokens=self.special_tokens,
)
self.tokenizer = enc
def __len__(self) -> int:
return self.tokenizer.n_vocab
def get_vocab(self) -> Dict[bytes, int]:
return self.mergeable_ranks
def convert_tokens_to_ids(self, tokens: Union[bytes, str, List[Union[bytes, str]]]) -> List[int]:
ids = []
if isinstance(tokens, (str, bytes)):
if tokens in self.special_tokens:
return self.special_tokens[tokens]
else:
return self.mergeable_ranks.get(tokens)
for token in tokens:
if token in self.special_tokens:
ids.append(self.special_tokens[token])
else:
ids.append(self.mergeable_ranks.get(token))
return ids
def _add_tokens(
self,
new_tokens: Union[List[str], List[AddedToken]],
special_tokens: bool = False,
) -> int:
if not special_tokens and new_tokens:
raise ValueError("Adding regular tokens is not supported")
for token in new_tokens:
surface_form = token.content if isinstance(token, AddedToken) else token
if surface_form not in SPECIAL_TOKENS_SET:
raise ValueError("Adding unknown special tokens is not supported")
return 0
def save_vocabulary(self, save_directory: str, **kwargs) -> Tuple[str]:
"""
Save only the vocabulary of the tokenizer (vocabulary).
Returns:
`Tuple(str)`: Paths to the files saved.
"""
file_path = os.path.join(save_directory, "qwen.tiktoken")
with open(file_path, "w", encoding="utf8") as w:
for k, v in self.mergeable_ranks.items():
line = base64.b64encode(k).decode("utf8") + " " + str(v) + "\n"
w.write(line)
return (file_path,)
def tokenize(
self,
text: str,
allowed_special: Union[Set, str] = "all",
disallowed_special: Union[Collection, str] = (),
**kwargs,
) -> List[Union[bytes, str]]:
"""
Converts a string in a sequence of tokens.
Args:
text (`str`):
The sequence to be encoded.
allowed_special (`Literal["all"]` or `set`):
The surface forms of the tokens to be encoded as special tokens in regular texts.
Default to "all".
disallowed_special (`Literal["all"]` or `Collection`):
The surface forms of the tokens that should not be in regular texts and trigger errors.
Default to an empty tuple.
kwargs (additional keyword arguments, *optional*):
Will be passed to the underlying model specific encode method.
Returns:
`List[bytes|str]`: The list of tokens.
"""
tokens = []
text = unicodedata.normalize("NFC", text)
# this implementation takes a detour: text -> token id -> token surface forms
for t in self.tokenizer.encode(text, allowed_special=allowed_special, disallowed_special=disallowed_special):
tokens.append(self.decoder[t])
return tokens
def convert_tokens_to_string(self, tokens: List[Union[bytes, str]]) -> str:
"""
Converts a sequence of tokens in a single string.
"""
text = ""
temp = b""
for t in tokens:
if isinstance(t, str):
if temp:
text += temp.decode("utf-8", errors=self.errors)
temp = b""
text += t
elif isinstance(t, bytes):
temp += t
else:
raise TypeError("token should only be of type types or str")
if temp:
text += temp.decode("utf-8", errors=self.errors)
return text
@property
def vocab_size(self):
return self.tokenizer.n_vocab
def _convert_id_to_token(self, index: int) -> Union[bytes, str]:
"""Converts an id to a token, special tokens included"""
if index in self.decoder:
return self.decoder[index]
raise ValueError("unknown ids")
def _convert_token_to_id(self, token: Union[bytes, str]) -> int:
"""Converts a token to an id using the vocab, special tokens included"""
if token in self.special_tokens:
return self.special_tokens[token]
if token in self.mergeable_ranks:
return self.mergeable_ranks[token]
raise ValueError("unknown token")
def _tokenize(self, text: str, **kwargs):
"""
Converts a string in a sequence of tokens (string), using the tokenizer. Split in words for word-based
vocabulary or sub-words for sub-word-based vocabularies (BPE/SentencePieces/WordPieces).
Do NOT take care of added tokens.
"""
raise NotImplementedError
def _decode(
self,
token_ids: Union[int, List[int]],
skip_special_tokens: bool = False,
errors: str = None,
**kwargs,
) -> str:
if isinstance(token_ids, int):
token_ids = [token_ids]
if skip_special_tokens:
token_ids = [i for i in token_ids if i < self.eod_id]
return self.tokenizer.decode(token_ids, errors=errors or self.errors)

10
wit/tokenizer_config.json Normal file
View File

@ -0,0 +1,10 @@
{
"model_max_length": 8192,
"tokenizer_class": "QWenTokenizer",
"auto_map": {
"AutoTokenizer": [
"tokenization_qwen.QWenTokenizer",
null
]
}
}