diff --git a/qwen/config.json b/qwen/config.json new file mode 100644 index 0000000..bccf46f --- /dev/null +++ b/qwen/config.json @@ -0,0 +1,37 @@ +{ + "architectures": [ + "QWenLMHeadModel" + ], + "auto_map": { + "AutoConfig": "configuration_qwen.QWenConfig", + "AutoModelForCausalLM": "modeling_qwen.QWenLMHeadModel" + }, + "attn_dropout_prob": 0.0, + "bf16": false, + "emb_dropout_prob": 0.0, + "fp16": false, + "fp32": false, + "hidden_size": 2048, + "intermediate_size": 11008, + "initializer_range": 0.02, + "kv_channels": 128, + "layer_norm_epsilon": 1e-06, + "max_position_embeddings": 8192, + "model_type": "qwen", + "no_bias": true, + "num_attention_heads": 16, + "num_hidden_layers": 24, + "onnx_safe": null, + "rotary_emb_base": 10000, + "rotary_pct": 1.0, + "scale_attn_weights": true, + "seq_length": 8192, + "tie_word_embeddings": false, + "tokenizer_class": "QWenTokenizer", + "transformers_version": "4.32.0", + "use_cache": true, + "use_dynamic_ntk": true, + "use_flash_attn": "auto", + "use_logn_attn": true, + "vocab_size": 151936 +} \ No newline at end of file diff --git a/qwen/configuration.json b/qwen/configuration.json new file mode 100644 index 0000000..db02e99 --- /dev/null +++ b/qwen/configuration.json @@ -0,0 +1,5 @@ +{ + "framework": "pytorch", + "task": "chat", + "allow_remote": true +} diff --git a/qwen/configuration_qwen.py b/qwen/configuration_qwen.py new file mode 100644 index 0000000..f8fe2cb --- /dev/null +++ b/qwen/configuration_qwen.py @@ -0,0 +1,71 @@ +# Copyright (c) Alibaba Cloud. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from transformers import PretrainedConfig + + +class QWenConfig(PretrainedConfig): + model_type = "qwen" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=151936, + hidden_size=4096, + num_hidden_layers=32, + num_attention_heads=32, + emb_dropout_prob=0.0, + attn_dropout_prob=0.0, + layer_norm_epsilon=1e-6, + initializer_range=0.02, + max_position_embeddings=8192, + scale_attn_weights=True, + use_cache=True, + bf16=False, + fp16=False, + fp32=False, + kv_channels=128, + rotary_pct=1.0, + rotary_emb_base=10000, + use_dynamic_ntk=True, + use_logn_attn=True, + use_flash_attn="auto", + intermediate_size=22016, + no_bias=True, + tie_word_embeddings=False, + use_cache_quantization=False, + use_cache_kernel=False, + softmax_in_fp32=False, + **kwargs, + ): + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.emb_dropout_prob = emb_dropout_prob + self.attn_dropout_prob = attn_dropout_prob + self.layer_norm_epsilon = layer_norm_epsilon + self.initializer_range = initializer_range + self.scale_attn_weights = scale_attn_weights + self.use_cache = use_cache + self.max_position_embeddings = max_position_embeddings + self.bf16 = bf16 + self.fp16 = fp16 + self.fp32 = fp32 + self.kv_channels = kv_channels + self.rotary_pct = rotary_pct + self.rotary_emb_base = rotary_emb_base + self.use_dynamic_ntk = use_dynamic_ntk + self.use_logn_attn = use_logn_attn + self.use_flash_attn = use_flash_attn + self.no_bias = no_bias + self.use_cache_quantization = use_cache_quantization + self.use_cache_kernel = use_cache_kernel + self.softmax_in_fp32 = softmax_in_fp32 + super().__init__( + tie_word_embeddings=tie_word_embeddings, + **kwargs + ) diff --git a/qwen/cpp_kernels.py b/qwen/cpp_kernels.py new 
file mode 100644 index 0000000..d9cee70 --- /dev/null +++ b/qwen/cpp_kernels.py @@ -0,0 +1,55 @@ +from torch.utils import cpp_extension +import pathlib +import os +import subprocess + +def _get_cuda_bare_metal_version(cuda_dir): + raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], + universal_newlines=True) + output = raw_output.split() + release_idx = output.index("release") + 1 + release = output[release_idx].split(".") + bare_metal_major = release[0] + bare_metal_minor = release[1][0] + + return raw_output, bare_metal_major, bare_metal_minor + +def _create_build_dir(buildpath): + try: + os.mkdir(buildpath) + except OSError: + if not os.path.isdir(buildpath): + print(f"Creation of the build directory {buildpath} failed") + +# Check if cuda 11 is installed for compute capability 8.0 +cc_flag = [] +_, bare_metal_major, bare_metal_minor = _get_cuda_bare_metal_version(cpp_extension.CUDA_HOME) +if int(bare_metal_major) >= 11: + cc_flag.append('-gencode') + cc_flag.append('arch=compute_80,code=sm_80') + if int(bare_metal_minor) >= 7: + cc_flag.append('-gencode') + cc_flag.append('arch=compute_90,code=sm_90') + +# Build path +srcpath = pathlib.Path(__file__).parent.absolute() +buildpath = srcpath / 'build' +_create_build_dir(buildpath) + +def _cpp_extention_load_helper(name, sources, extra_cuda_flags): + return cpp_extension.load( + name=name, + sources=sources, + build_directory=buildpath, + extra_cflags=['-O3', ], + extra_cuda_cflags=['-O3', + '-gencode', 'arch=compute_70,code=sm_70', + '--use_fast_math'] + extra_cuda_flags + cc_flag, + verbose=1 + ) + +extra_flags = [] + +cache_autogptq_cuda_256_sources = ["./cache_autogptq_cuda_256.cpp", + "./cache_autogptq_cuda_kernel_256.cu"] +cache_autogptq_cuda_256 = _cpp_extention_load_helper("cache_autogptq_cuda_256", cache_autogptq_cuda_256_sources, extra_flags) diff --git a/qwen/demo.py b/qwen/demo.py index 05d58d6..189e462 100644 --- a/qwen/demo.py +++ b/qwen/demo.py @@ -1,11 +1,30 @@ + +import torch from modelscope import snapshot_download from transformers import AutoModelForCausalLM, AutoTokenizer from transformers.generation import GenerationConfig +from transformers import AutoConfig + +from modeling_qwen import QWenLMHeadModel + +seed = 4321 +torch.manual_seed(seed) +torch.cuda.manual_seed_all(seed) model_dir = snapshot_download("qwen/Qwen-1_8B-Chat") +config, kwargs = AutoConfig.from_pretrained( + model_dir, + return_unused_kwargs=True, + trust_remote_code=True, + code_revision=None, + _commit_hash=None, +) +model = QWenLMHeadModel(config) + + tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True) -model = AutoModelForCausalLM.from_pretrained( +model = model.from_pretrained( model_dir, device_map="auto", trust_remote_code=True ).eval() diff --git a/qwen/generation_config.json b/qwen/generation_config.json new file mode 100644 index 0000000..a66a26d --- /dev/null +++ b/qwen/generation_config.json @@ -0,0 +1,12 @@ +{ + "chat_format": "chatml", + "eos_token_id": 151643, + "pad_token_id": 151643, + "max_window_size": 6144, + "max_new_tokens": 512, + "do_sample": true, + "top_k": 0, + "top_p": 0.8, + "repetition_penalty": 1.1, + "transformers_version": "4.31.0" +} \ No newline at end of file diff --git a/qwen/model.safetensors.index.json b/qwen/model.safetensors.index.json new file mode 100644 index 0000000..9192b9d --- /dev/null +++ b/qwen/model.safetensors.index.json @@ -0,0 +1,202 @@ +{ + "metadata": { + "total_size": 3673657344 + }, + "weight_map": { + "lm_head.weight": 
"model-00002-of-00002.safetensors", + "transformer.h.0.attn.c_attn.bias": "model-00001-of-00002.safetensors", + "transformer.h.0.attn.c_attn.weight": "model-00001-of-00002.safetensors", + "transformer.h.0.attn.c_proj.weight": "model-00001-of-00002.safetensors", + "transformer.h.0.ln_1.weight": "model-00001-of-00002.safetensors", + "transformer.h.0.ln_2.weight": "model-00001-of-00002.safetensors", + "transformer.h.0.mlp.c_proj.weight": "model-00001-of-00002.safetensors", + "transformer.h.0.mlp.w1.weight": "model-00001-of-00002.safetensors", + "transformer.h.0.mlp.w2.weight": "model-00001-of-00002.safetensors", + "transformer.h.1.attn.c_attn.bias": "model-00001-of-00002.safetensors", + "transformer.h.1.attn.c_attn.weight": "model-00001-of-00002.safetensors", + "transformer.h.1.attn.c_proj.weight": "model-00001-of-00002.safetensors", + "transformer.h.1.ln_1.weight": "model-00001-of-00002.safetensors", + "transformer.h.1.ln_2.weight": "model-00001-of-00002.safetensors", + "transformer.h.1.mlp.c_proj.weight": "model-00001-of-00002.safetensors", + "transformer.h.1.mlp.w1.weight": "model-00001-of-00002.safetensors", + "transformer.h.1.mlp.w2.weight": "model-00001-of-00002.safetensors", + "transformer.h.10.attn.c_attn.bias": "model-00001-of-00002.safetensors", + "transformer.h.10.attn.c_attn.weight": "model-00001-of-00002.safetensors", + "transformer.h.10.attn.c_proj.weight": "model-00001-of-00002.safetensors", + "transformer.h.10.ln_1.weight": "model-00001-of-00002.safetensors", + "transformer.h.10.ln_2.weight": "model-00001-of-00002.safetensors", + "transformer.h.10.mlp.c_proj.weight": "model-00001-of-00002.safetensors", + "transformer.h.10.mlp.w1.weight": "model-00001-of-00002.safetensors", + "transformer.h.10.mlp.w2.weight": "model-00001-of-00002.safetensors", + "transformer.h.11.attn.c_attn.bias": "model-00001-of-00002.safetensors", + "transformer.h.11.attn.c_attn.weight": "model-00001-of-00002.safetensors", + "transformer.h.11.attn.c_proj.weight": "model-00001-of-00002.safetensors", + "transformer.h.11.ln_1.weight": "model-00001-of-00002.safetensors", + "transformer.h.11.ln_2.weight": "model-00001-of-00002.safetensors", + "transformer.h.11.mlp.c_proj.weight": "model-00001-of-00002.safetensors", + "transformer.h.11.mlp.w1.weight": "model-00001-of-00002.safetensors", + "transformer.h.11.mlp.w2.weight": "model-00001-of-00002.safetensors", + "transformer.h.12.attn.c_attn.bias": "model-00001-of-00002.safetensors", + "transformer.h.12.attn.c_attn.weight": "model-00001-of-00002.safetensors", + "transformer.h.12.attn.c_proj.weight": "model-00001-of-00002.safetensors", + "transformer.h.12.ln_1.weight": "model-00001-of-00002.safetensors", + "transformer.h.12.ln_2.weight": "model-00001-of-00002.safetensors", + "transformer.h.12.mlp.c_proj.weight": "model-00001-of-00002.safetensors", + "transformer.h.12.mlp.w1.weight": "model-00001-of-00002.safetensors", + "transformer.h.12.mlp.w2.weight": "model-00001-of-00002.safetensors", + "transformer.h.13.attn.c_attn.bias": "model-00001-of-00002.safetensors", + "transformer.h.13.attn.c_attn.weight": "model-00001-of-00002.safetensors", + "transformer.h.13.attn.c_proj.weight": "model-00001-of-00002.safetensors", + "transformer.h.13.ln_1.weight": "model-00001-of-00002.safetensors", + "transformer.h.13.ln_2.weight": "model-00001-of-00002.safetensors", + "transformer.h.13.mlp.c_proj.weight": "model-00001-of-00002.safetensors", + "transformer.h.13.mlp.w1.weight": "model-00001-of-00002.safetensors", + "transformer.h.13.mlp.w2.weight": "model-00001-of-00002.safetensors", 
+ "transformer.h.14.attn.c_attn.bias": "model-00002-of-00002.safetensors", + "transformer.h.14.attn.c_attn.weight": "model-00002-of-00002.safetensors", + "transformer.h.14.attn.c_proj.weight": "model-00002-of-00002.safetensors", + "transformer.h.14.ln_1.weight": "model-00001-of-00002.safetensors", + "transformer.h.14.ln_2.weight": "model-00002-of-00002.safetensors", + "transformer.h.14.mlp.c_proj.weight": "model-00002-of-00002.safetensors", + "transformer.h.14.mlp.w1.weight": "model-00002-of-00002.safetensors", + "transformer.h.14.mlp.w2.weight": "model-00002-of-00002.safetensors", + "transformer.h.15.attn.c_attn.bias": "model-00002-of-00002.safetensors", + "transformer.h.15.attn.c_attn.weight": "model-00002-of-00002.safetensors", + "transformer.h.15.attn.c_proj.weight": "model-00002-of-00002.safetensors", + "transformer.h.15.ln_1.weight": "model-00002-of-00002.safetensors", + "transformer.h.15.ln_2.weight": "model-00002-of-00002.safetensors", + "transformer.h.15.mlp.c_proj.weight": "model-00002-of-00002.safetensors", + "transformer.h.15.mlp.w1.weight": "model-00002-of-00002.safetensors", + "transformer.h.15.mlp.w2.weight": "model-00002-of-00002.safetensors", + "transformer.h.16.attn.c_attn.bias": "model-00002-of-00002.safetensors", + "transformer.h.16.attn.c_attn.weight": "model-00002-of-00002.safetensors", + "transformer.h.16.attn.c_proj.weight": "model-00002-of-00002.safetensors", + "transformer.h.16.ln_1.weight": "model-00002-of-00002.safetensors", + "transformer.h.16.ln_2.weight": "model-00002-of-00002.safetensors", + "transformer.h.16.mlp.c_proj.weight": "model-00002-of-00002.safetensors", + "transformer.h.16.mlp.w1.weight": "model-00002-of-00002.safetensors", + "transformer.h.16.mlp.w2.weight": "model-00002-of-00002.safetensors", + "transformer.h.17.attn.c_attn.bias": "model-00002-of-00002.safetensors", + "transformer.h.17.attn.c_attn.weight": "model-00002-of-00002.safetensors", + "transformer.h.17.attn.c_proj.weight": "model-00002-of-00002.safetensors", + "transformer.h.17.ln_1.weight": "model-00002-of-00002.safetensors", + "transformer.h.17.ln_2.weight": "model-00002-of-00002.safetensors", + "transformer.h.17.mlp.c_proj.weight": "model-00002-of-00002.safetensors", + "transformer.h.17.mlp.w1.weight": "model-00002-of-00002.safetensors", + "transformer.h.17.mlp.w2.weight": "model-00002-of-00002.safetensors", + "transformer.h.18.attn.c_attn.bias": "model-00002-of-00002.safetensors", + "transformer.h.18.attn.c_attn.weight": "model-00002-of-00002.safetensors", + "transformer.h.18.attn.c_proj.weight": "model-00002-of-00002.safetensors", + "transformer.h.18.ln_1.weight": "model-00002-of-00002.safetensors", + "transformer.h.18.ln_2.weight": "model-00002-of-00002.safetensors", + "transformer.h.18.mlp.c_proj.weight": "model-00002-of-00002.safetensors", + "transformer.h.18.mlp.w1.weight": "model-00002-of-00002.safetensors", + "transformer.h.18.mlp.w2.weight": "model-00002-of-00002.safetensors", + "transformer.h.19.attn.c_attn.bias": "model-00002-of-00002.safetensors", + "transformer.h.19.attn.c_attn.weight": "model-00002-of-00002.safetensors", + "transformer.h.19.attn.c_proj.weight": "model-00002-of-00002.safetensors", + "transformer.h.19.ln_1.weight": "model-00002-of-00002.safetensors", + "transformer.h.19.ln_2.weight": "model-00002-of-00002.safetensors", + "transformer.h.19.mlp.c_proj.weight": "model-00002-of-00002.safetensors", + "transformer.h.19.mlp.w1.weight": "model-00002-of-00002.safetensors", + "transformer.h.19.mlp.w2.weight": "model-00002-of-00002.safetensors", + 
"transformer.h.2.attn.c_attn.bias": "model-00001-of-00002.safetensors", + "transformer.h.2.attn.c_attn.weight": "model-00001-of-00002.safetensors", + "transformer.h.2.attn.c_proj.weight": "model-00001-of-00002.safetensors", + "transformer.h.2.ln_1.weight": "model-00001-of-00002.safetensors", + "transformer.h.2.ln_2.weight": "model-00001-of-00002.safetensors", + "transformer.h.2.mlp.c_proj.weight": "model-00001-of-00002.safetensors", + "transformer.h.2.mlp.w1.weight": "model-00001-of-00002.safetensors", + "transformer.h.2.mlp.w2.weight": "model-00001-of-00002.safetensors", + "transformer.h.20.attn.c_attn.bias": "model-00002-of-00002.safetensors", + "transformer.h.20.attn.c_attn.weight": "model-00002-of-00002.safetensors", + "transformer.h.20.attn.c_proj.weight": "model-00002-of-00002.safetensors", + "transformer.h.20.ln_1.weight": "model-00002-of-00002.safetensors", + "transformer.h.20.ln_2.weight": "model-00002-of-00002.safetensors", + "transformer.h.20.mlp.c_proj.weight": "model-00002-of-00002.safetensors", + "transformer.h.20.mlp.w1.weight": "model-00002-of-00002.safetensors", + "transformer.h.20.mlp.w2.weight": "model-00002-of-00002.safetensors", + "transformer.h.21.attn.c_attn.bias": "model-00002-of-00002.safetensors", + "transformer.h.21.attn.c_attn.weight": "model-00002-of-00002.safetensors", + "transformer.h.21.attn.c_proj.weight": "model-00002-of-00002.safetensors", + "transformer.h.21.ln_1.weight": "model-00002-of-00002.safetensors", + "transformer.h.21.ln_2.weight": "model-00002-of-00002.safetensors", + "transformer.h.21.mlp.c_proj.weight": "model-00002-of-00002.safetensors", + "transformer.h.21.mlp.w1.weight": "model-00002-of-00002.safetensors", + "transformer.h.21.mlp.w2.weight": "model-00002-of-00002.safetensors", + "transformer.h.22.attn.c_attn.bias": "model-00002-of-00002.safetensors", + "transformer.h.22.attn.c_attn.weight": "model-00002-of-00002.safetensors", + "transformer.h.22.attn.c_proj.weight": "model-00002-of-00002.safetensors", + "transformer.h.22.ln_1.weight": "model-00002-of-00002.safetensors", + "transformer.h.22.ln_2.weight": "model-00002-of-00002.safetensors", + "transformer.h.22.mlp.c_proj.weight": "model-00002-of-00002.safetensors", + "transformer.h.22.mlp.w1.weight": "model-00002-of-00002.safetensors", + "transformer.h.22.mlp.w2.weight": "model-00002-of-00002.safetensors", + "transformer.h.23.attn.c_attn.bias": "model-00002-of-00002.safetensors", + "transformer.h.23.attn.c_attn.weight": "model-00002-of-00002.safetensors", + "transformer.h.23.attn.c_proj.weight": "model-00002-of-00002.safetensors", + "transformer.h.23.ln_1.weight": "model-00002-of-00002.safetensors", + "transformer.h.23.ln_2.weight": "model-00002-of-00002.safetensors", + "transformer.h.23.mlp.c_proj.weight": "model-00002-of-00002.safetensors", + "transformer.h.23.mlp.w1.weight": "model-00002-of-00002.safetensors", + "transformer.h.23.mlp.w2.weight": "model-00002-of-00002.safetensors", + "transformer.h.3.attn.c_attn.bias": "model-00001-of-00002.safetensors", + "transformer.h.3.attn.c_attn.weight": "model-00001-of-00002.safetensors", + "transformer.h.3.attn.c_proj.weight": "model-00001-of-00002.safetensors", + "transformer.h.3.ln_1.weight": "model-00001-of-00002.safetensors", + "transformer.h.3.ln_2.weight": "model-00001-of-00002.safetensors", + "transformer.h.3.mlp.c_proj.weight": "model-00001-of-00002.safetensors", + "transformer.h.3.mlp.w1.weight": "model-00001-of-00002.safetensors", + "transformer.h.3.mlp.w2.weight": "model-00001-of-00002.safetensors", + "transformer.h.4.attn.c_attn.bias": 
"model-00001-of-00002.safetensors", + "transformer.h.4.attn.c_attn.weight": "model-00001-of-00002.safetensors", + "transformer.h.4.attn.c_proj.weight": "model-00001-of-00002.safetensors", + "transformer.h.4.ln_1.weight": "model-00001-of-00002.safetensors", + "transformer.h.4.ln_2.weight": "model-00001-of-00002.safetensors", + "transformer.h.4.mlp.c_proj.weight": "model-00001-of-00002.safetensors", + "transformer.h.4.mlp.w1.weight": "model-00001-of-00002.safetensors", + "transformer.h.4.mlp.w2.weight": "model-00001-of-00002.safetensors", + "transformer.h.5.attn.c_attn.bias": "model-00001-of-00002.safetensors", + "transformer.h.5.attn.c_attn.weight": "model-00001-of-00002.safetensors", + "transformer.h.5.attn.c_proj.weight": "model-00001-of-00002.safetensors", + "transformer.h.5.ln_1.weight": "model-00001-of-00002.safetensors", + "transformer.h.5.ln_2.weight": "model-00001-of-00002.safetensors", + "transformer.h.5.mlp.c_proj.weight": "model-00001-of-00002.safetensors", + "transformer.h.5.mlp.w1.weight": "model-00001-of-00002.safetensors", + "transformer.h.5.mlp.w2.weight": "model-00001-of-00002.safetensors", + "transformer.h.6.attn.c_attn.bias": "model-00001-of-00002.safetensors", + "transformer.h.6.attn.c_attn.weight": "model-00001-of-00002.safetensors", + "transformer.h.6.attn.c_proj.weight": "model-00001-of-00002.safetensors", + "transformer.h.6.ln_1.weight": "model-00001-of-00002.safetensors", + "transformer.h.6.ln_2.weight": "model-00001-of-00002.safetensors", + "transformer.h.6.mlp.c_proj.weight": "model-00001-of-00002.safetensors", + "transformer.h.6.mlp.w1.weight": "model-00001-of-00002.safetensors", + "transformer.h.6.mlp.w2.weight": "model-00001-of-00002.safetensors", + "transformer.h.7.attn.c_attn.bias": "model-00001-of-00002.safetensors", + "transformer.h.7.attn.c_attn.weight": "model-00001-of-00002.safetensors", + "transformer.h.7.attn.c_proj.weight": "model-00001-of-00002.safetensors", + "transformer.h.7.ln_1.weight": "model-00001-of-00002.safetensors", + "transformer.h.7.ln_2.weight": "model-00001-of-00002.safetensors", + "transformer.h.7.mlp.c_proj.weight": "model-00001-of-00002.safetensors", + "transformer.h.7.mlp.w1.weight": "model-00001-of-00002.safetensors", + "transformer.h.7.mlp.w2.weight": "model-00001-of-00002.safetensors", + "transformer.h.8.attn.c_attn.bias": "model-00001-of-00002.safetensors", + "transformer.h.8.attn.c_attn.weight": "model-00001-of-00002.safetensors", + "transformer.h.8.attn.c_proj.weight": "model-00001-of-00002.safetensors", + "transformer.h.8.ln_1.weight": "model-00001-of-00002.safetensors", + "transformer.h.8.ln_2.weight": "model-00001-of-00002.safetensors", + "transformer.h.8.mlp.c_proj.weight": "model-00001-of-00002.safetensors", + "transformer.h.8.mlp.w1.weight": "model-00001-of-00002.safetensors", + "transformer.h.8.mlp.w2.weight": "model-00001-of-00002.safetensors", + "transformer.h.9.attn.c_attn.bias": "model-00001-of-00002.safetensors", + "transformer.h.9.attn.c_attn.weight": "model-00001-of-00002.safetensors", + "transformer.h.9.attn.c_proj.weight": "model-00001-of-00002.safetensors", + "transformer.h.9.ln_1.weight": "model-00001-of-00002.safetensors", + "transformer.h.9.ln_2.weight": "model-00001-of-00002.safetensors", + "transformer.h.9.mlp.c_proj.weight": "model-00001-of-00002.safetensors", + "transformer.h.9.mlp.w1.weight": "model-00001-of-00002.safetensors", + "transformer.h.9.mlp.w2.weight": "model-00001-of-00002.safetensors", + "transformer.ln_f.weight": "model-00002-of-00002.safetensors", + "transformer.wte.weight": 
"model-00001-of-00002.safetensors" + } +} diff --git a/qwen/modeling_qwen.py b/qwen/modeling_qwen.py index 6500378..93a691e 100644 --- a/qwen/modeling_qwen.py +++ b/qwen/modeling_qwen.py @@ -40,8 +40,8 @@ SUPPORT_FP16 = SUPPORT_CUDA and torch.cuda.get_device_capability(0)[0] >= 7 SUPPORT_TORCH2 = hasattr(torch, '__version__') and int(torch.__version__.split(".")[0]) >= 2 -from .configuration_qwen import QWenConfig -from .qwen_generation_utils import ( +from configuration_qwen import QWenConfig +from qwen_generation_utils import ( HistoryType, make_context, decode_tokens, @@ -520,7 +520,9 @@ class QWenAttention(nn.Module): if not self.use_cache_quantization and SUPPORT_TORCH2: if attention_mask is not None: - attention_mask = attention_mask.expand(-1, -1, query.size(2), -1) + attention_mask = attention_mask.expand( + -1, -1, causal_mask.size(2), -1 + ) if causal_mask is not None: attention_mask = attention_mask.masked_fill(~causal_mask, torch.finfo(query.dtype).min) else: @@ -1328,14 +1330,14 @@ def apply_rotary_pos_emb(t, freqs): t (tensor(batch_size, seq_len, n_head, head_dim)): the input embedding/hidden states freqs (list[tensor(1, seq_len, 1, rotary_dim), tensor(1, seq_len, 1, rotary_dim)]): - the cached cos/sin position embeddings + the cached cos/sin position embeddings """ rot_dim = freqs[0].shape[-1] cos, sin = freqs t_float = t.float() if apply_rotary_emb_func is not None and t.is_cuda: - # apply_rotary_emb in flash_attn requires cos/sin to be of - # shape (seqlen, rotary_dim / 2) and apply rotary embedding + # apply_rotary_emb in flash_attn requires cos/sin to be of + # shape (seqlen, rotary_dim / 2) and apply rotary embedding # to the first rotary_dim of the input cos = cos.squeeze(0).squeeze(1)[:, : rot_dim // 2] sin = sin.squeeze(0).squeeze(1)[:, : rot_dim // 2] @@ -1360,4 +1362,4 @@ class RMSNorm(torch.nn.Module): return rms_norm(x, self.weight, self.eps) else: output = self._norm(x.float()).type_as(x) - return output * self.weight \ No newline at end of file + return output * self.weight diff --git a/qwen/qwen_generation_utils.py b/qwen/qwen_generation_utils.py new file mode 100644 index 0000000..4e8e1d8 --- /dev/null +++ b/qwen/qwen_generation_utils.py @@ -0,0 +1,416 @@ +# Copyright (c) Alibaba Cloud. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +"""Generation support.""" + +from typing import Tuple, List, Union, Iterable + +import numpy as np +import torch +import torch.nn.functional as F +from transformers import PreTrainedTokenizer +from transformers import logging +from transformers.generation import LogitsProcessor + +logger = logging.get_logger(__name__) + +# Types. +HistoryType = List[Tuple[str, str]] +TokensType = List[int] +BatchTokensType = List[List[int]] + + +def pad_batch(batch: BatchTokensType, pad_id: int, seq_length: int) -> BatchTokensType: + for tokens in batch: + context_length = len(tokens) + if context_length < seq_length: + tokens.extend([pad_id] * (seq_length - context_length)) + return batch + + +def get_ltor_masks_and_position_ids( + data, + eod_token, + reset_position_ids, + reset_attention_mask, + eod_mask_loss, +): + """Build masks and position id for left to right model.""" + + # Extract batch size and sequence length. + micro_batch_size, seq_length = data.size() + + # Attention mask (lower triangular). 
+    if reset_attention_mask:
+        att_mask_batch = micro_batch_size
+    else:
+        att_mask_batch = 1
+    attention_mask = torch.tril(
+        torch.ones((att_mask_batch, seq_length, seq_length), device=data.device)
+    ).view(att_mask_batch, 1, seq_length, seq_length)
+
+    # Loss mask.
+    loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device)
+    if eod_mask_loss:
+        loss_mask[data == eod_token] = 0.0
+
+    # Position ids.
+    position_ids = torch.arange(seq_length, dtype=torch.long, device=data.device)
+    position_ids = position_ids.unsqueeze(0).expand_as(data)
+    # We need to clone as the ids will be modified based on batch index.
+    if reset_position_ids:
+        position_ids = position_ids.clone()
+
+    if reset_position_ids or reset_attention_mask:
+        # Loop through the batches:
+        for b in range(micro_batch_size):
+
+            # Find indices where EOD token is.
+            eod_index = position_ids[b, data[b] == eod_token]
+            # Detach indices from positions if going to modify positions.
+            if reset_position_ids:
+                eod_index = eod_index.clone()
+
+            # Loop through EOD indices:
+            prev_index = 0
+            for j in range(eod_index.size()[0]):
+                i = eod_index[j]
+                # Mask attention loss.
+                if reset_attention_mask:
+                    attention_mask[b, 0, (i + 1) :, : (i + 1)] = 0
+                # Reset positions.
+                if reset_position_ids:
+                    position_ids[b, (i + 1) :] -= i + 1 - prev_index
+                    prev_index = i + 1
+
+    # Convert attention mask to binary:
+    attention_mask = attention_mask < 0.5
+
+    return attention_mask, loss_mask, position_ids
+
+
+def get_batch(context_tokens: torch.LongTensor, eod_id: int):
+    """Generate batch from context tokens."""
+    # Move to GPU.
+    tokens = context_tokens.contiguous().to(context_tokens.device)
+    # Get the attention mask and position ids.
+    attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
+        tokens,
+        eod_id,
+        reset_position_ids=False,
+        reset_attention_mask=False,
+        eod_mask_loss=False,
+    )
+    return tokens, attention_mask, position_ids
+
+
+def get_stop_words_ids(chat_format, tokenizer):
+    if chat_format == "raw":
+        stop_words_ids = [tokenizer.encode("Human:"), [tokenizer.eod_id]]
+    elif chat_format == "chatml":
+        stop_words_ids = [[tokenizer.im_end_id], [tokenizer.im_start_id]]
+    else:
+        raise NotImplementedError(f"Unknown chat format {chat_format!r}")
+    return stop_words_ids
+
+
+def make_context(
+    tokenizer: PreTrainedTokenizer,
+    query: str,
+    history: List[Tuple[str, str]] = None,
+    system: str = "",
+    max_window_size: int = 6144,
+    chat_format: str = "chatml",
+):
+    if history is None:
+        history = []
+
+    if chat_format == "chatml":
+        im_start, im_end = "<|im_start|>", "<|im_end|>"
+        im_start_tokens = [tokenizer.im_start_id]
+        im_end_tokens = [tokenizer.im_end_id]
+        nl_tokens = tokenizer.encode("\n")
+
+        def _tokenize_str(role, content):
+            return f"{role}\n{content}", tokenizer.encode(
+                role, allowed_special=set()
+            ) + nl_tokens + tokenizer.encode(content, allowed_special=set())
+
+        system_text, system_tokens_part = _tokenize_str("system", system)
+        system_tokens = im_start_tokens + system_tokens_part + im_end_tokens
+
+        raw_text = ""
+        context_tokens = []
+
+        for turn_query, turn_response in reversed(history):
+            query_text, query_tokens_part = _tokenize_str("user", turn_query)
+            query_tokens = im_start_tokens + query_tokens_part + im_end_tokens
+            response_text, response_tokens_part = _tokenize_str(
+                "assistant", turn_response
+            )
+            response_tokens = im_start_tokens + response_tokens_part + im_end_tokens
+
+            next_context_tokens = nl_tokens + query_tokens + nl_tokens + response_tokens
+            prev_chat = (
+
f"\n{im_start}{query_text}{im_end}\n{im_start}{response_text}{im_end}" + ) + + current_context_size = ( + len(system_tokens) + len(next_context_tokens) + len(context_tokens) + ) + if current_context_size < max_window_size: + context_tokens = next_context_tokens + context_tokens + raw_text = prev_chat + raw_text + else: + break + + context_tokens = system_tokens + context_tokens + raw_text = f"{im_start}{system_text}{im_end}" + raw_text + context_tokens += ( + nl_tokens + + im_start_tokens + + _tokenize_str("user", query)[1] + + im_end_tokens + + nl_tokens + + im_start_tokens + + tokenizer.encode("assistant") + + nl_tokens + ) + raw_text += f"\n{im_start}user\n{query}{im_end}\n{im_start}assistant\n" + + elif chat_format == "raw": + raw_text = query + context_tokens = tokenizer.encode(raw_text) + else: + raise NotImplementedError(f"Unknown chat format {chat_format!r}") + + return raw_text, context_tokens + + +def _decode_default( + tokens: List[int], + *, + stop_words: List[str], + eod_words: List[str], + tokenizer: PreTrainedTokenizer, + raw_text_len: int, + verbose: bool = False, + return_end_reason: bool = False, + errors: str='replace', +): + trim_decode_tokens = tokenizer.decode(tokens, errors=errors)[raw_text_len:] + if verbose: + print("\nRaw Generate: ", trim_decode_tokens) + + end_reason = f"Gen length {len(tokens)}" + for stop_word in stop_words: + trim_decode_tokens = trim_decode_tokens.replace(stop_word, "").strip() + for eod_word in eod_words: + if eod_word in trim_decode_tokens: + end_reason = f"Gen {eod_word!r}" + trim_decode_tokens = trim_decode_tokens.split(eod_word)[0] + trim_decode_tokens = trim_decode_tokens.strip() + if verbose: + print("\nEnd Reason:", end_reason) + print("\nGenerate: ", trim_decode_tokens) + + if return_end_reason: + return trim_decode_tokens, end_reason + else: + return trim_decode_tokens + + +def _decode_chatml( + tokens: List[int], + *, + stop_words: List[str], + eod_token_ids: List[int], + tokenizer: PreTrainedTokenizer, + raw_text_len: int, + context_length: int, + verbose: bool = False, + return_end_reason: bool = False, + errors: str='replace' +): + end_reason = f"Gen length {len(tokens)}" + eod_token_idx = context_length + for eod_token_idx in range(context_length, len(tokens)): + if tokens[eod_token_idx] in eod_token_ids: + end_reason = f"Gen {tokenizer.decode([tokens[eod_token_idx]])!r}" + break + + trim_decode_tokens = tokenizer.decode(tokens[:eod_token_idx], errors=errors)[raw_text_len:] + if verbose: + print("\nRaw Generate w/o EOD:", tokenizer.decode(tokens, errors=errors)[raw_text_len:]) + print("\nRaw Generate:", trim_decode_tokens) + print("\nEnd Reason:", end_reason) + for stop_word in stop_words: + trim_decode_tokens = trim_decode_tokens.replace(stop_word, "").strip() + trim_decode_tokens = trim_decode_tokens.strip() + if verbose: + print("\nGenerate:", trim_decode_tokens) + + if return_end_reason: + return trim_decode_tokens, end_reason + else: + return trim_decode_tokens + + +def decode_tokens( + tokens: Union[torch.LongTensor, TokensType], + tokenizer: PreTrainedTokenizer, + raw_text_len: int, + context_length: int, + chat_format: str, + verbose: bool = False, + return_end_reason: bool = False, + errors: str="replace", +) -> str: + if torch.is_tensor(tokens): + tokens = tokens.cpu().numpy().tolist() + + if chat_format == "chatml": + return _decode_chatml( + tokens, + stop_words=[], + eod_token_ids=[tokenizer.im_start_id, tokenizer.im_end_id], + tokenizer=tokenizer, + raw_text_len=raw_text_len, + context_length=context_length, + 
verbose=verbose, + return_end_reason=return_end_reason, + errors=errors, + ) + elif chat_format == "raw": + return _decode_default( + tokens, + stop_words=["<|endoftext|>"], + eod_words=["<|endoftext|>"], + tokenizer=tokenizer, + raw_text_len=raw_text_len, + verbose=verbose, + return_end_reason=return_end_reason, + errors=errors, + ) + else: + raise NotImplementedError(f"Unknown chat format {chat_format!r}") + + +class StopWordsLogitsProcessor(LogitsProcessor): + """ + :class:`transformers.LogitsProcessor` that enforces that when specified sequences appear, stop geration. + + Args: + stop_words_ids (:obj:`List[List[int]]`): + List of list of token ids of stop ids. In order to get the tokens of the words + that should not appear in the generated text, use :obj:`tokenizer(bad_word, + add_prefix_space=True).input_ids`. + eos_token_id (:obj:`int`): + The id of the `end-of-sequence` token. + """ + + def __init__(self, stop_words_ids: Iterable[Iterable[int]], eos_token_id: int): + + if not isinstance(stop_words_ids, List) or len(stop_words_ids) == 0: + raise ValueError( + f"`stop_words_ids` has to be a non-emtpy list, but is {stop_words_ids}." + ) + if any(not isinstance(bad_word_ids, list) for bad_word_ids in stop_words_ids): + raise ValueError( + f"`stop_words_ids` has to be a list of lists, but is {stop_words_ids}." + ) + if any( + any( + (not isinstance(token_id, (int, np.integer)) or token_id < 0) + for token_id in stop_word_ids + ) + for stop_word_ids in stop_words_ids + ): + raise ValueError( + f"Each list in `stop_words_ids` has to be a list of positive integers, but is {stop_words_ids}." + ) + + self.stop_words_ids = list( + filter( + lambda bad_token_seq: bad_token_seq != [eos_token_id], stop_words_ids + ) + ) + self.eos_token_id = eos_token_id + for stop_token_seq in self.stop_words_ids: + assert ( + len(stop_token_seq) > 0 + ), "Stop words token sequences {} cannot have an empty list".format( + stop_words_ids + ) + + def __call__( + self, input_ids: torch.LongTensor, scores: torch.FloatTensor + ) -> torch.FloatTensor: + stopped_samples = self._calc_stopped_samples(input_ids) + for i, should_stop in enumerate(stopped_samples): + if should_stop: + scores[i, self.eos_token_id] = float(2**15) + return scores + + def _tokens_match(self, prev_tokens: torch.LongTensor, tokens: List[int]) -> bool: + if len(tokens) == 0: + # if bad word tokens is just one token always ban it + return True + elif len(tokens) > len(prev_tokens): + # if bad word tokens are longer then prev input_ids they can't be equal + return False + elif prev_tokens[-len(tokens) :].tolist() == tokens: + # if tokens match + return True + else: + return False + + def _calc_stopped_samples(self, prev_input_ids: Iterable[int]) -> Iterable[int]: + stopped_samples = [] + for prev_input_ids_slice in prev_input_ids: + match = False + for stop_token_seq in self.stop_words_ids: + if self._tokens_match(prev_input_ids_slice, stop_token_seq): + # if tokens do not match continue + match = True + break + stopped_samples.append(match) + + return stopped_samples + + +def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float("Inf")): + """This function has been mostly taken from huggingface conversational + ai code at + https://medium.com/huggingface/how-to-build-a-state-of-the-art- + conversational-ai-with-transfer-learning-2d818ac26313""" + + if top_k > 0: + # Remove all tokens with a probability less than the + # last token of the top-k + indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] + 
logits[indices_to_remove] = filter_value + + if top_p > 0.0: + # Cconvert to 1D + sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1) + cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) + + # Remove tokens with cumulative probability above the threshold + sorted_indices_to_remove = cumulative_probs > top_p + # Shift the indices to the right to keep also the first token + # above the threshold + sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() + sorted_indices_to_remove[..., 0] = 0 + for i in range(sorted_indices.size(0)): + indices_to_remove = sorted_indices[i][sorted_indices_to_remove[i]] + logits[i][indices_to_remove] = filter_value + + return logits + + +def switch(val1, val2, boolean): + boolean = boolean.type_as(val1) + return (1 - boolean) * val1 + boolean * val2 diff --git a/qwen/tokenization_qwen.py b/qwen/tokenization_qwen.py new file mode 100644 index 0000000..2a526d6 --- /dev/null +++ b/qwen/tokenization_qwen.py @@ -0,0 +1,276 @@ +# Copyright (c) Alibaba Cloud. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +"""Tokenization classes for QWen.""" + +import base64 +import logging +import os +import unicodedata +from typing import Collection, Dict, List, Set, Tuple, Union + +import tiktoken +from transformers import PreTrainedTokenizer, AddedToken + +logger = logging.getLogger(__name__) + + +VOCAB_FILES_NAMES = {"vocab_file": "qwen.tiktoken"} + +PAT_STR = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""" +ENDOFTEXT = "<|endoftext|>" +IMSTART = "<|im_start|>" +IMEND = "<|im_end|>" +# as the default behavior is changed to allow special tokens in +# regular texts, the surface forms of special tokens need to be +# as different as possible to minimize the impact +EXTRAS = tuple((f"<|extra_{i}|>" for i in range(205))) +# changed to use actual index to avoid misconfiguration with vocabulary expansion +SPECIAL_START_ID = 151643 +SPECIAL_TOKENS = tuple( + enumerate( + ( + ( + ENDOFTEXT, + IMSTART, + IMEND, + ) + + EXTRAS + ), + start=SPECIAL_START_ID, + ) +) +SPECIAL_TOKENS_SET = set(t for i, t in SPECIAL_TOKENS) + + +def _load_tiktoken_bpe(tiktoken_bpe_file: str) -> Dict[bytes, int]: + with open(tiktoken_bpe_file, "rb") as f: + contents = f.read() + return { + base64.b64decode(token): int(rank) + for token, rank in (line.split() for line in contents.splitlines() if line) + } + + +class QWenTokenizer(PreTrainedTokenizer): + """QWen tokenizer.""" + + vocab_files_names = VOCAB_FILES_NAMES + + def __init__( + self, + vocab_file, + errors="replace", + extra_vocab_file=None, + **kwargs, + ): + super().__init__(**kwargs) + + # how to handle errors in decoding UTF-8 byte sequences + # use ignore if you are in streaming inference + self.errors = errors + + self.mergeable_ranks = _load_tiktoken_bpe(vocab_file) # type: Dict[bytes, int] + self.special_tokens = { + token: index + for index, token in SPECIAL_TOKENS + } + + # try load extra vocab from file + if extra_vocab_file is not None: + used_ids = set(self.mergeable_ranks.values()) | set(self.special_tokens.values()) + extra_mergeable_ranks = _load_tiktoken_bpe(extra_vocab_file) + for token, index in extra_mergeable_ranks.items(): + if token in self.mergeable_ranks: + logger.info(f"extra token {token} exists, skipping") + continue + if index in used_ids: + logger.info(f'the index {index} for extra token {token} exists, 
skipping') + continue + self.mergeable_ranks[token] = index + # the index may be sparse after this, but don't worry tiktoken.Encoding will handle this + + enc = tiktoken.Encoding( + "Qwen", + pat_str=PAT_STR, + mergeable_ranks=self.mergeable_ranks, + special_tokens=self.special_tokens, + ) + assert ( + len(self.mergeable_ranks) + len(self.special_tokens) == enc.n_vocab + ), f"{len(self.mergeable_ranks) + len(self.special_tokens)} != {enc.n_vocab} in encoding" + + self.decoder = { + v: k for k, v in self.mergeable_ranks.items() + } # type: dict[int, bytes|str] + self.decoder.update({v: k for k, v in self.special_tokens.items()}) + + self.tokenizer = enc # type: tiktoken.Encoding + + self.eod_id = self.tokenizer.eot_token + self.im_start_id = self.special_tokens[IMSTART] + self.im_end_id = self.special_tokens[IMEND] + + def __getstate__(self): + # for pickle lovers + state = self.__dict__.copy() + del state["tokenizer"] + return state + + def __setstate__(self, state): + # tokenizer is not python native; don't pass it; rebuild it + self.__dict__.update(state) + enc = tiktoken.Encoding( + "Qwen", + pat_str=PAT_STR, + mergeable_ranks=self.mergeable_ranks, + special_tokens=self.special_tokens, + ) + self.tokenizer = enc + + def __len__(self) -> int: + return self.tokenizer.n_vocab + + def get_vocab(self) -> Dict[bytes, int]: + return self.mergeable_ranks + + def convert_tokens_to_ids( + self, tokens: Union[bytes, str, List[Union[bytes, str]]] + ) -> List[int]: + ids = [] + if isinstance(tokens, (str, bytes)): + if tokens in self.special_tokens: + return self.special_tokens[tokens] + else: + return self.mergeable_ranks.get(tokens) + for token in tokens: + if token in self.special_tokens: + ids.append(self.special_tokens[token]) + else: + ids.append(self.mergeable_ranks.get(token)) + return ids + + def _add_tokens( + self, + new_tokens: Union[List[str], List[AddedToken]], + special_tokens: bool = False, + ) -> int: + if not special_tokens and new_tokens: + raise ValueError("Adding regular tokens is not supported") + for token in new_tokens: + surface_form = token.content if isinstance(token, AddedToken) else token + if surface_form not in SPECIAL_TOKENS_SET: + raise ValueError("Adding unknown special tokens is not supported") + return 0 + + def save_vocabulary(self, save_directory: str, **kwargs) -> Tuple[str]: + """ + Save only the vocabulary of the tokenizer (vocabulary). + + Returns: + `Tuple(str)`: Paths to the files saved. + """ + file_path = os.path.join(save_directory, "qwen.tiktoken") + with open(file_path, "w", encoding="utf8") as w: + for k, v in self.mergeable_ranks.items(): + line = base64.b64encode(k).decode("utf8") + " " + str(v) + "\n" + w.write(line) + return (file_path,) + + def tokenize( + self, + text: str, + allowed_special: Union[Set, str] = "all", + disallowed_special: Union[Collection, str] = (), + **kwargs, + ) -> List[Union[bytes, str]]: + """ + Converts a string in a sequence of tokens. + + Args: + text (`str`): + The sequence to be encoded. + allowed_special (`Literal["all"]` or `set`): + The surface forms of the tokens to be encoded as special tokens in regular texts. + Default to "all". + disallowed_special (`Literal["all"]` or `Collection`): + The surface forms of the tokens that should not be in regular texts and trigger errors. + Default to an empty tuple. + + kwargs (additional keyword arguments, *optional*): + Will be passed to the underlying model specific encode method. + + Returns: + `List[bytes|str]`: The list of tokens. 
+ """ + tokens = [] + text = unicodedata.normalize("NFC", text) + + # this implementation takes a detour: text -> token id -> token surface forms + for t in self.tokenizer.encode( + text, allowed_special=allowed_special, disallowed_special=disallowed_special + ): + tokens.append(self.decoder[t]) + return tokens + + def convert_tokens_to_string(self, tokens: List[Union[bytes, str]]) -> str: + """ + Converts a sequence of tokens in a single string. + """ + text = "" + temp = b"" + for t in tokens: + if isinstance(t, str): + if temp: + text += temp.decode("utf-8", errors=self.errors) + temp = b"" + text += t + elif isinstance(t, bytes): + temp += t + else: + raise TypeError("token should only be of type types or str") + if temp: + text += temp.decode("utf-8", errors=self.errors) + return text + + @property + def vocab_size(self): + return self.tokenizer.n_vocab + + def _convert_id_to_token(self, index: int) -> Union[bytes, str]: + """Converts an id to a token, special tokens included""" + if index in self.decoder: + return self.decoder[index] + raise ValueError("unknown ids") + + def _convert_token_to_id(self, token: Union[bytes, str]) -> int: + """Converts a token to an id using the vocab, special tokens included""" + if token in self.special_tokens: + return self.special_tokens[token] + if token in self.mergeable_ranks: + return self.mergeable_ranks[token] + raise ValueError("unknown token") + + def _tokenize(self, text: str, **kwargs): + """ + Converts a string in a sequence of tokens (string), using the tokenizer. Split in words for word-based + vocabulary or sub-words for sub-word-based vocabularies (BPE/SentencePieces/WordPieces). + + Do NOT take care of added tokens. + """ + raise NotImplementedError + + def _decode( + self, + token_ids: Union[int, List[int]], + skip_special_tokens: bool = False, + errors: str = None, + **kwargs, + ) -> str: + if isinstance(token_ids, int): + token_ids = [token_ids] + if skip_special_tokens: + token_ids = [i for i in token_ids if i < self.eod_id] + return self.tokenizer.decode(token_ids, errors=errors or self.errors) diff --git a/qwen/tokenizer_config.json b/qwen/tokenizer_config.json new file mode 100644 index 0000000..9c37cac --- /dev/null +++ b/qwen/tokenizer_config.json @@ -0,0 +1,10 @@ +{ + "model_max_length": 8192, + "tokenizer_class": "QWenTokenizer", + "auto_map": { + "AutoTokenizer": [ + "tokenization_qwen.QWenTokenizer", + null + ] + } +}