Delete unused files.

This commit is contained in:
    parent a15e55bead
    commit 89c12380cb

(One file diff is suppressed because it is too large.)

@@ -1,30 +0,0 @@
{
  "architectures": [
    "T5ForConditionalGeneration"
  ],
  "d_ff": 2048,
  "d_kv": 64,
  "d_model": 768,
  "decoder_start_token_id": 0,
  "dense_act_fn": "gelu_new",
  "dropout_rate": 0.1,
  "eos_token_id": 1,
  "feed_forward_proj": "gated-gelu",
  "initializer_factor": 1.0,
  "is_encoder_decoder": true,
  "is_gated_act": true,
  "layer_norm_epsilon": 1e-06,
  "model_type": "t5",
  "num_decoder_layers": 12,
  "num_heads": 12,
  "num_layers": 12,
  "output_past": true,
  "pad_token_id": 0,
  "relative_attention_max_distance": 128,
  "relative_attention_num_buckets": 32,
  "tie_word_embeddings": false,
  "torch_dtype": "float32",
  "transformers_version": "4.26.0.dev0",
  "use_cache": true,
  "vocab_size": 32128
}
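For reference, a config.json like the one deleted above is what transformers reads when loading the checkpoint; a minimal sketch (the local directory path is hypothetical, not part of the commit):

# Minimal sketch, not part of the diff: loading a T5-style config.json with transformers.
# "./promptclue-base" is a hypothetical local directory containing the file above.
from transformers import T5Config

config = T5Config.from_pretrained("./promptclue-base")
print(config.d_model, config.num_layers, config.feed_forward_proj)  # with the values above: 768 12 gated-gelu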
@@ -1,11 +0,0 @@
{
    "framework": "pytorch",
    "task": "text2text-generation",
    "model": {
        "type": "T5",
        "language": "zh"
    },
    "pipeline": {
        "type": "text2text-generation"
    }
}
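The configuration.json above is what ModelScope inspects to pick a task and pipeline type for a model directory; a rough sketch of the assumed usage (the local path is hypothetical):

# Rough sketch, not part of the diff: the "task"/"pipeline" fields above let modelscope
# build a pipeline directly from a local model directory (path is hypothetical).
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

pipe = pipeline(task=Tasks.text2text_generation, model="./promptclue-base")
print(pipe("翻译成英文:\n今天天气不错\n答案:"))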
@@ -1,175 +0,0 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# Copyright 2020, The T5 Authors and HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" T5 model configuration"""
from typing import Mapping

from transformers.configuration_utils import PretrainedConfig
from transformers.onnx import OnnxSeq2SeqConfigWithPast

from modelscope.utils.logger import get_logger

logger = get_logger()


class T5Config(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`T5Model`] or a [`TFT5Model`]. It is used to
    instantiate a T5 model according to the specified arguments, defining the model architecture. Instantiating a
    configuration with the defaults will yield a similar configuration to that of the T5
    [t5-small](https://huggingface.co/t5-small) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Arguments:
        vocab_size (`int`, *optional*, defaults to 32128):
            Vocabulary size of the T5 model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`T5Model`] or [`TFT5Model`].
        d_model (`int`, *optional*, defaults to 512):
            Size of the encoder layers and the pooler layer.
        d_kv (`int`, *optional*, defaults to 64):
            Size of the key, query, value projections per attention head. `d_kv` has to be equal to `d_model //
            num_heads`.
        d_ff (`int`, *optional*, defaults to 2048):
            Size of the intermediate feed forward layer in each `T5Block`.
        num_layers (`int`, *optional*, defaults to 6):
            Number of hidden layers in the Transformer encoder.
        num_decoder_layers (`int`, *optional*):
            Number of hidden layers in the Transformer decoder. Will use the same value as `num_layers` if not set.
        num_heads (`int`, *optional*, defaults to 8):
            Number of attention heads for each attention layer in the Transformer encoder.
        relative_attention_num_buckets (`int`, *optional*, defaults to 32):
            The number of buckets to use for each attention layer.
        relative_attention_max_distance (`int`, *optional*, defaults to 128):
            The maximum distance of the longer sequences for the bucket separation.
        dropout_rate (`float`, *optional*, defaults to 0.1):
            The ratio for all dropout layers.
        layer_norm_eps (`float`, *optional*, defaults to 1e-6):
            The epsilon used by the layer normalization layers.
        initializer_factor (`float`, *optional*, defaults to 1):
            A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
            testing).
        feed_forward_proj (`string`, *optional*, defaults to `"relu"`):
            Type of feed forward layer to be used. Should be one of `"relu"` or `"gated-gelu"`. T5v1.1 uses the
            `"gated-gelu"` feed forward projection. Original T5 uses `"relu"`.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models).
    """
    model_type = 't5'
    keys_to_ignore_at_inference = ['past_key_values']
    attribute_map = {
        'hidden_size': 'd_model',
        'num_attention_heads': 'num_heads',
        'num_hidden_layers': 'num_layers'
    }

    def __init__(self,
                 vocab_size=32128,
                 d_model=512,
                 d_kv=64,
                 d_ff=2048,
                 num_layers=6,
                 num_decoder_layers=None,
                 num_heads=8,
                 relative_attention_num_buckets=32,
                 relative_attention_max_distance=128,
                 dropout_rate=0.1,
                 layer_norm_epsilon=1e-6,
                 initializer_factor=1.0,
                 feed_forward_proj='relu',
                 is_encoder_decoder=True,
                 use_cache=True,
                 pad_token_id=0,
                 eos_token_id=1,
                 **kwargs):
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.d_kv = d_kv
        self.d_ff = d_ff
        self.num_layers = num_layers
        self.num_decoder_layers = (num_decoder_layers if num_decoder_layers
                                   is not None else self.num_layers
                                   )  # default = symmetry
        self.num_heads = num_heads
        self.relative_attention_num_buckets = relative_attention_num_buckets
        self.relative_attention_max_distance = relative_attention_max_distance
        self.dropout_rate = dropout_rate
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_factor = initializer_factor
        self.feed_forward_proj = feed_forward_proj
        self.use_cache = use_cache

        act_info = self.feed_forward_proj.split('-')
        self.dense_act_fn = act_info[-1]
        self.is_gated_act = act_info[0] == 'gated'

        if len(act_info) > 1 and act_info[0] != 'gated' or len(act_info) > 2:
            raise ValueError(
                f'`feed_forward_proj`: {feed_forward_proj} is not a valid activation function of the dense layer.'
                'Please make sure `feed_forward_proj` is of the format `gated-{ACT_FN}` or `{ACT_FN}`, e.g. '
                "'gated-gelu' or 'relu'")

        # for backwards compatibility
        if feed_forward_proj == 'gated-gelu':
            self.dense_act_fn = 'gelu_new'

        super().__init__(
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            is_encoder_decoder=is_encoder_decoder,
            **kwargs,
        )


class T5OnnxConfig(OnnxSeq2SeqConfigWithPast):

    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        common_inputs = {
            'input_ids': {
                0: 'batch',
                1: 'encoder_sequence'
            },
            'attention_mask': {
                0: 'batch',
                1: 'encoder_sequence'
            },
        }
        if self.use_past:
            common_inputs['attention_mask'][
                1] = 'past_encoder_sequence + sequence'
            common_inputs['decoder_input_ids'] = {0: 'batch'}
            common_inputs['decoder_attention_mask'] = {
                0: 'batch',
                1: 'past_decoder_sequence + sequence'
            }
        else:
            common_inputs['decoder_input_ids'] = {
                0: 'batch',
                1: 'decoder_sequence'
            }
            common_inputs['decoder_attention_mask'] = {
                0: 'batch',
                1: 'decoder_sequence'
            }

        if self.use_past:
            self.fill_with_past_key_values_(common_inputs, direction='inputs')

        return common_inputs

    @property
    def default_onnx_opset(self) -> int:
        return 13
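The `feed_forward_proj` parsing in `T5Config.__init__` above derives `dense_act_fn` and `is_gated_act`; a small sketch of the resulting attributes:

# Small sketch, not part of the diff: how the T5Config above derives its activation settings.
from configuration import T5Config

config = T5Config(d_model=768, num_layers=12, num_heads=12, feed_forward_proj='gated-gelu')
print(config.is_gated_act)        # True  ('gated-' prefix)
print(config.dense_act_fn)        # 'gelu_new' (backwards-compatibility branch for 'gated-gelu')
print(config.num_decoder_layers)  # 12, falls back to num_layers when not given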
@@ -1,89 +0,0 @@
import torch
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

from modelscope.preprocessors import TextGenerationTransformersPreprocessor

from modeling_t5 import T5ForConditionalGeneration
from modelscope.utils.config import Config
from configuration import T5Config
from modelscope import snapshot_download
from transformers import AutoConfig


seed = 4321
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

model_dir = snapshot_download("ClueAI/PromptCLUE-base-v1-5")

# config = T5Config()

config, kwargs = AutoConfig.from_pretrained(
    model_dir,
    return_unused_kwargs=True,
    trust_remote_code=True,
    code_revision=None,
    _commit_hash=None,
)


model = T5ForConditionalGeneration(config)
model = model.from_pretrained(model_dir)

preprocessor = TextGenerationTransformersPreprocessor(model_dir)

out = preprocessor._tokenize_text("生成与下列文字相同意")

tokenizer = preprocessor.nlp_tokenizer
response, history = model.chat(tokenizer, "生成与下列文字相同意", history=[])

# model = T5ForConditionalGeneration.from_pretrained(
#     "ClueAI/PromptCLUE-base-v1-5", revision="v0.1"
# )

pipeline_t2t = pipeline(
    task=Tasks.text2text_generation, model=model, preprocessor=preprocessor
)

print(pipeline_t2t("生成与下列文字相同意思的句子:\n白云遍地无人扫\n答案:", do_sample=True, top_p=0.8))
# {'text': '白云散去无踪,没人扫。'}

# print(pipeline_t2t('改写下面的文字,确保意思相同:\n一个如此藐视本国人民民主权利的人,怎么可能捍卫外国人的民权?\n答案:', do_sample=True, top_p=0.8))
# # {'text': '对一个如此藐视本国人民民主权利的人,怎么能捍卫外国人的民权?'}

# print(pipeline_t2t('根据问题给出答案:\n问题:手指发麻的主要可能病因是:\n答案'))
# # {'text': '神经损伤,颈椎病,贫血,高血压'}

# print(pipeline_t2t('问答:\n问题:黄果悬钩子的目是:\n答案:'))
# # {'text': '蔷薇目'}

# print(pipeline_t2t('情感分析:\n这个看上去还可以,但其实我不喜欢\n选项:积极,消极'))
# # {'text': '消极'}

# print(pipeline_t2t("下面句子是否表示了相同的语义:\n文本1:糖尿病腿麻木怎么办?\n文本2:糖尿病怎样控制生活方式\n选项:相似,不相似\n答案:"))
# # {'text': '不相似'}

# print(pipeline_t2t('这是关于哪方面的新闻:\n如果日本沉没,中国会接收日本难民吗?\n选项:故事,文化,娱乐,体育,财经,房产,汽车,教育,科技,军事,旅游,国际,股票,农业,游戏'))
# # {'text': '国际'}

# print(pipeline_t2t("阅读文本抽取关键信息:\n张玄武1990年出生中国国籍无境外居留权博士学历现任杭州线锁科技技术总监。\n问题:机构,人名,职位,籍贯,专业,国籍,学历,种族\n答案:"))
# # {'text': '机构:杭州线锁科技技术_人名:张玄武_职位:博士学历'}

# print(pipeline_t2t("翻译成英文:\n杀不死我的只会让我更强大\n答案:"))
# # {'text': 'To kill my life only let me stronger'}

# print(pipeline_t2t('为下面的文章生成摘要:\n北京时间9月5日12时52分,四川甘孜藏族自治州泸定县发生6.8级地震。地震发生后,领导高度重视并作出重要指示,要求把抢救生命作为首要任务,全力救援受灾群众,最大限度减少人员伤亡'))
# # {'text': '四川甘孜发生6.8级地震'}

# print(pipeline_t2t("推理关系判断:\n前提:小明今天在北京\n假设:小明在深圳旅游\n选项:矛盾,蕴含,中立\n答案:"))
# # {'text': '蕴涵'}

# print(pipeline_t2t('阅读以下对话并回答问题。\n男:今天怎么这么晚才来上班啊?女:昨天工作到很晚,而且我还感冒了。男:那你回去休息吧,我帮你请假。女:谢谢你。\n问题:女的怎么样?\n选项:正在工作,感冒了,在打电话,要出差。'))
# # {'text': '感冒了'}

# print(pipeline_t2t("文本纠错:\n告诉二营长,叫他彻回来,我李云龙从不打没有准备的杖\n答案:"))
# #{'text':'告诉二营长,叫他下来,我李云龙从不打没有准备的仗'}

# print(pipeline_t2t("问答:\n问题:小米的创始人是谁?\n答案:"))
# # {'text': '小米创始人:雷军'}
@@ -1,406 +0,0 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# Copyright 2018 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import warnings
from typing import Optional, Tuple, Union, List, Dict, Any


import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers.utils.model_parallel_utils import assert_device_map, get_device_map

from modelscope.metainfo import Models
from modelscope.models.builder import MODELS
from modelscope.outputs import (
    AttentionBackboneModelOutput,
    Seq2SeqLMOutput,
    TokenGeneratorOutput,
)
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger
from backbone import T5PreTrainedModel, T5Stack
from configuration import T5Config

logger = get_logger()

# Warning message for FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
__HEAD_MASK_WARNING_MSG = """
The input argument `head_mask` was split into two arguments `head_mask` and
`decoder_head_mask`. Currently, `decoder_head_mask` is set to copy `head_mask`,
but this feature is deprecated and will be removed in future versions. If you do
not want to use any `decoder_head_mask` now, please set `decoder_head_mask =
torch.ones(num_layers, num_heads)`.
"""


class T5ForConditionalGeneration(T5PreTrainedModel):
    _keys_to_ignore_on_load_missing = [
        r"encoder\.embed_tokens\.weight",
        r"decoder\.embed_tokens\.weight",
        r"lm_head\.weight",
    ]
    _keys_to_ignore_on_load_unexpected = [
        r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight",
    ]

    def __init__(self, config: T5Config, device_map=None, **kwargs):
        super().__init__(config)
        self.model_dim = config.d_model

        self.shared = nn.Embedding(config.vocab_size, config.d_model)

        encoder_config = copy.deepcopy(config)
        encoder_config.is_decoder = False
        encoder_config.use_cache = False
        encoder_config.is_encoder_decoder = False
        self.encoder = T5Stack(encoder_config, self.shared)

        decoder_config = copy.deepcopy(config)
        decoder_config.is_decoder = True
        decoder_config.is_encoder_decoder = False
        decoder_config.num_layers = config.num_decoder_layers
        self.decoder = T5Stack(decoder_config, self.shared)

        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

        # Model parallel
        self.model_parallel = False
        if device_map == "auto":
            self.parallelize()

    def parallelize(self, device_map=None):
        self.device_map = (
            get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
            if device_map is None
            else device_map
        )
        assert_device_map(self.device_map, len(self.encoder.block))
        self.encoder.parallelize(self.device_map)
        self.decoder.parallelize(self.device_map)
        self.lm_head = self.lm_head.to(self.decoder.first_device)
        self.model_parallel = True

    def deparallelize(self):
        self.encoder.deparallelize()
        self.decoder.deparallelize()
        self.encoder = self.encoder.to("cpu")
        self.decoder = self.decoder.to("cpu")
        self.lm_head = self.lm_head.to("cpu")
        self.model_parallel = False
        self.device_map = None
        torch.cuda.empty_cache()

    def get_input_embeddings(self):
        return self.shared

    def set_input_embeddings(self, new_embeddings):
        self.shared = new_embeddings
        self.encoder.set_input_embeddings(new_embeddings)
        self.decoder.set_input_embeddings(new_embeddings)

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def get_output_embeddings(self):
        return self.lm_head

    def get_encoder(self):
        return self.encoder

    def get_decoder(self):
        return self.decoder

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        decoder_head_mask: Optional[torch.FloatTensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
        if head_mask is not None and decoder_head_mask is None:
            if self.config.num_layers == self.config.num_decoder_layers:
                warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
                decoder_head_mask = head_mask

        # Encode if needed (training, first prediction pass)
        if encoder_outputs is None:
            # Convert encoder inputs in embeddings if needed
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        elif return_dict and not isinstance(
            encoder_outputs, AttentionBackboneModelOutput
        ):
            encoder_outputs = AttentionBackboneModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        hidden_states = encoder_outputs[0]

        if self.model_parallel:
            torch.cuda.set_device(self.decoder.first_device)

        if (
            labels is not None
            and decoder_input_ids is None
            and decoder_inputs_embeds is None
        ):
            # get decoder inputs from shifting lm labels to the right
            decoder_input_ids = self._shift_right(labels)

        # Set device for model parallelism
        if self.model_parallel:
            torch.cuda.set_device(self.decoder.first_device)
            hidden_states = hidden_states.to(self.decoder.first_device)
            if decoder_input_ids is not None:
                decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
            if attention_mask is not None:
                attention_mask = attention_mask.to(self.decoder.first_device)
            if decoder_attention_mask is not None:
                decoder_attention_mask = decoder_attention_mask.to(
                    self.decoder.first_device
                )

        # Decode
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            past_key_values=past_key_values,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=attention_mask,
            head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = decoder_outputs[0]

        # Set device for model parallelism
        if self.model_parallel:
            torch.cuda.set_device(self.encoder.first_device)
            self.lm_head = self.lm_head.to(self.encoder.first_device)
            sequence_output = sequence_output.to(self.lm_head.weight.device)

        if self.config.tie_word_embeddings:
            # Rescale output before projecting on vocab See
            # https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
            sequence_output = sequence_output * (self.model_dim ** -0.5)

        lm_logits = self.lm_head(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
            # TODO(thom): Add z_loss
            # https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666

        if not return_dict:
            output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
            return ((loss,) + output) if loss is not None else output

        return Seq2SeqLMOutput(
            loss=loss,
            logits=lm_logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past=None,
        attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        cross_attn_head_mask=None,
        use_cache=None,
        encoder_outputs=None,
        **kwargs,
    ):

        # cut decoder_input_ids if past is used
        if past is not None:
            input_ids = input_ids[:, -1:]

        return {
            "decoder_input_ids": input_ids,
            "past_key_values": past,
            "encoder_outputs": encoder_outputs,
            "attention_mask": attention_mask,
            "head_mask": head_mask,
            "decoder_head_mask": decoder_head_mask,
            "cross_attn_head_mask": cross_attn_head_mask,
            "use_cache": use_cache,
        }

    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
        return self._shift_right(labels)

    def generate(
        self,
        *args,
        **kwargs,
    ):
        output = super().generate(*args, **kwargs)
        return TokenGeneratorOutput(
            sequences=output if isinstance(output, torch.Tensor) else output[0]
        )

    def chat(
        self,
        tokenizer,
        query: str,
        history: List[Tuple[str, str]] = None,
        role: str = "user",
    ):
        if history is None:
            history = []
        token = tokenizer(query)
        inputs = torch.as_tensor([token["input_ids"]])
        inputs_tensor = inputs.to(next(self.parameters()).device)

        generation_config = copy.deepcopy(self.generation_config)
        # inputs_tensor = inputs["input_ids"]
        input_ids = inputs_tensor.repeat_interleave(
            generation_config.num_return_sequences, dim=0
        )

        outputs = self.sample(
            input_ids,
            generation_config.pad_token_id,
            generation_config.eos_token_id,
            generation_config.output_hidden_states,
            tokenizer,
        )

        outputs = outputs.tolist()[0][:]
        response = tokenizer.decode(outputs)
        history.append({"role": role, "content": query})
        return response, history

    def sample(
        self,
        input_ids: torch.LongTensor,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[Union[int, List[int]]] = None,
        output_hidden_states: Optional[bool] = None,
        tokenizer=None,
    ):
        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]
        eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device)

        isFinished = torch.zeros(
            input_ids.shape[0], dtype=torch.long, device=input_ids.device
        )
        # token_count = 0
        while True:
            input_ids_in = input_ids
            # batch_size, seq_length = input_ids_in.shape
            # position_ids_in = (
            #     torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
            #     .unsqueeze(0)
            #     .repeat(batch_size, 1)
            # )
            # model_inputs = {"input_ids": input_ids_in, "position_ids": position_ids_in}

            # input_ids_in = self.prepare_inputs_for_generation(input_ids)
            probs, next_tokens = self(input_ids)
            #     **model_inputs,
            #     output_hidden_states=output_hidden_states,
            #     tokenizer=tokenizer,
            # )

            # finished sentences should add a padding token to next
            pad_token = pad_token_id * isFinished
            next_tokens = next_tokens * (1 - isFinished) + pad_token

            isFinished = isFinished | next_tokens.eq(eos_token_id_tensor)
            if isFinished.min() == 1:  # all batch is finish
                break

            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        return input_ids

    def _reorder_cache(self, past, beam_idx):
        # if decoder past is not included in output
        # speedy decoding is disabled and no need to reorder
        if past is None:
            logger.warning(
                "You might want to consider setting `use_cache=True` to speed up decoding"
            )
            return past

        reordered_decoder_past = ()
        for layer_past_states in past:
            # get the correct batch idx from layer past batch dim
            # batch dim of `past` is at 2nd position
            reordered_layer_past_states = ()
            for layer_past_state in layer_past_states:
                # need to set correct `past` for each of the four key / value states
                reordered_layer_past_states = reordered_layer_past_states + (
                    layer_past_state.index_select(
                        0, beam_idx.to(layer_past_state.device)
                    ),
                )

            assert reordered_layer_past_states[0].shape == layer_past_states[0].shape
            assert len(reordered_layer_past_states) == len(layer_past_states)

            reordered_decoder_past = reordered_decoder_past + (
                reordered_layer_past_states,
            )
        return reordered_decoder_past
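When `labels` are given, `forward` above shifts them right into `decoder_input_ids` and computes a cross-entropy loss over the vocabulary; a rough sketch with a tiny random-weight model, assuming `backbone.T5Stack` behaves like the Hugging Face T5 stack (all sizes are hypothetical):

# Rough sketch, not part of the diff: tiny random-weight model, hypothetical sizes,
# assuming backbone.T5Stack mirrors the Hugging Face T5 implementation.
import torch
from configuration import T5Config
from modeling_t5 import T5ForConditionalGeneration

tiny = T5Config(vocab_size=1000, d_model=64, d_kv=16, d_ff=128,
                num_layers=2, num_heads=4, decoder_start_token_id=0)
model = T5ForConditionalGeneration(tiny)
input_ids = torch.randint(0, 1000, (2, 8))
labels = torch.randint(0, 1000, (2, 6))
out = model(input_ids=input_ids, labels=labels, return_dict=True)
print(out.loss, out.logits.shape)  # logits: torch.Size([2, 6, 1000])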
@@ -1,33 +0,0 @@
{
  "_name_or_path": "./bloomz_560m_pretrained",
  "apply_residual_connection_post_layernorm": false,
  "architectures": [
    "BloomForCausalLM"
  ],
  "attention_dropout": 0.0,
  "attention_softmax_in_fp32": true,
  "bias_dropout_fusion": true,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "hidden_dropout": 0.0,
  "hidden_size": 1024,
  "initializer_range": 0.02,
  "layer_norm_epsilon": 1e-05,
  "masked_softmax_fusion": true,
  "model_type": "bloom",
  "n_head": 16,
  "n_inner": null,
  "n_layer": 24,
  "offset_alibi": 100,
  "pad_token_id": 3,
  "pretraining_tp": 1,
  "seq_length": 2048,
  "skip_bias_add": true,
  "skip_bias_add_qkv": false,
  "slow_but_exact": false,
  "torch_dtype": "float16",
  "transformers_version": "4.28.1",
  "unk_token_id": 0,
  "use_cache": true,
  "vocab_size": 250880
}
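The Bloom config above maps onto a `BloomConfig`, whose `attribute_map` (see the deleted configuration.py below) exposes the `n_layer`/`n_head` fields under standard names; a minimal sketch with the values from the JSON above:

# Minimal sketch, not part of the diff: instantiating a BloomConfig with the sizes above.
from transformers import BloomConfig

config = BloomConfig(hidden_size=1024, n_layer=24, n_head=16, vocab_size=250880)
print(config.num_hidden_layers, config.num_attention_heads)  # 24 16, via attribute_map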
@@ -1,11 +0,0 @@
{
    "framework": "pytorch",
    "task": "text-generation",
    "model": {
        "type": "bloom"
    },
    "pipeline": {
        "type": "seqgpt"
    },
    "allow_remote": true
}
@@ -1,242 +0,0 @@
# coding=utf-8
# Copyright 2022 the Big Science Workshop and HuggingFace Inc. team.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Bloom configuration"""
from collections import OrderedDict
from typing import TYPE_CHECKING, Any, List, Mapping, Optional

from packaging import version


if TYPE_CHECKING:
    from ... import PreTrainedTokenizer, TensorType

from transformers.configuration_utils import PretrainedConfig
from transformers.onnx import OnnxConfigWithPast, PatchingSpec
from transformers.utils import is_torch_available, logging


logger = logging.get_logger(__name__)

BLOOM_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "bigscience/bloom": "https://huggingface.co/bigscience/bloom/resolve/main/config.json",
    "bigscience/bloom-560m": "https://huggingface.co/bigscience/bloom-560m/blob/main/config.json",
    "bigscience/bloom-1b1": "https://huggingface.co/bigscience/bloom-1b1/blob/main/config.json",
    "bigscience/bloom-1b7": "https://huggingface.co/bigscience/bloom-1b7/blob/main/config.json",
    "bigscience/bloom-3b": "https://huggingface.co/bigscience/bloom-3b/blob/main/config.json",
    "bigscience/bloom-7b1": "https://huggingface.co/bigscience/bloom-7b1/blob/main/config.json",
}


class BloomConfig(PretrainedConfig):
    """
    This is the configuration class to store the configuration of a [`BloomModel`]. It is used to instantiate a Bloom
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to the Bloom architecture
    [bigscience/bloom](https://huggingface.co/bigscience/bloom).

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        vocab_size (`int`, *optional*, defaults to 250880):
            Vocabulary size of the Bloom model. Defines the maximum number of different tokens that can be represented
            by the `inputs_ids` passed when calling [`BloomModel`]. Check [this
            discussion](https://huggingface.co/bigscience/bloom/discussions/120#633d28389addb8530b406c2a) on how the
            `vocab_size` has been defined.
        hidden_size (`int`, *optional*, defaults to 64):
            Dimensionality of the embeddings and hidden states.
        n_layer (`int`, *optional*, defaults to 2):
            Number of hidden layers in the Transformer encoder.
        n_head (`int`, *optional*, defaults to 8):
            Number of attention heads for each attention layer in the Transformer encoder.
        layer_norm_epsilon (`float`, *optional*, defaults to 1e-5):
            The epsilon to use in the layer normalization layers.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        apply_residual_connection_post_layernorm (`bool`, *optional*, defaults to `False`):
            If enabled, use the layer norm of the hidden states as the residual in the transformer blocks
        hidden_dropout (`float`, *optional*, defaults to 0.1):
            Dropout rate of the dropout function on the bias dropout.
        attention_dropout (`float`, *optional*, defaults to 0.1):
            Dropout rate applied to the attention probs
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models).
        pretraining_tp (`int`, *optional*, defaults to `1`):
            Experimental feature. Tensor parallelism rank used during pretraining with Megatron. Please refer to [this
            document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is
            necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
            issue](https://github.com/pytorch/pytorch/issues/76232). Note also that this is enabled only when
            `slow_but_exact=True`.
        slow_but_exact (`bool`, *optional*, defaults to `False`):
            Experimental feature. Whether to use slow but exact implementation of the attention mechanism. While
            merging the TP rank tensors, due to slicing operations the results may be slightly different between the
            model trained on Megatron and our model. Please refer to [this
            issue](https://github.com/pytorch/pytorch/issues/76232). A solution to obtain more accurate results is to
            enable this feature. Enabling this will hurt the computational time of the inference. Will be probably
            resolved in the future once the main model has been fine-tuned with TP_rank=1.

    Example:

    ```python
    >>> from transformers import BloomConfig, BloomModel

    >>> # Initializing a Bloom configuration
    >>> configuration = BloomConfig()

    >>> # Initializing a model (with random weights) from the configuration
    >>> model = BloomModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "bloom"
    keys_to_ignore_at_inference = ["past_key_values"]
    attribute_map = {
        "num_hidden_layers": "n_layer",
        "num_attention_heads": "n_head",
    }

    def __init__(
        self,
        vocab_size=250880,
        hidden_size=64,
        n_layer=2,
        n_head=8,
        layer_norm_epsilon=1e-5,
        initializer_range=0.02,
        use_cache=True,
        bos_token_id=1,
        eos_token_id=2,
        apply_residual_connection_post_layernorm=False,
        hidden_dropout=0.0,
        attention_dropout=0.0,
        pretraining_tp=1,  # TP rank used when training with megatron
        slow_but_exact=False,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        # Backward compatibility with n_embed kwarg
        n_embed = kwargs.pop("n_embed", None)
        self.hidden_size = hidden_size if n_embed is None else n_embed
        self.n_layer = n_layer
        self.n_head = n_head
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_range = initializer_range
        self.use_cache = use_cache
        self.pretraining_tp = pretraining_tp
 | 
					 | 
				
			||||||
        self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
 | 
					 | 
				
			||||||
        self.hidden_dropout = hidden_dropout
 | 
					 | 
				
			||||||
        self.attention_dropout = attention_dropout
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        self.bos_token_id = bos_token_id
 | 
					 | 
				
			||||||
        self.eos_token_id = eos_token_id
 | 
					 | 
				
			||||||
        self.slow_but_exact = slow_but_exact
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
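As a quick illustration of the `n_embed` backward-compatibility handling and the `attribute_map` aliases defined above, a minimal sketch (assuming the class is importable as `transformers.BloomConfig`, which this file mirrors):

```python
from transformers import BloomConfig

# `n_embed` is accepted only for backward compatibility and overrides `hidden_size`.
config = BloomConfig(n_embed=1024, n_layer=4, n_head=16)
print(config.hidden_size)          # 1024 (taken from the legacy `n_embed` kwarg)

# `attribute_map` aliases the generic names onto the Bloom-specific attributes.
print(config.num_hidden_layers)    # 4  -> config.n_layer
print(config.num_attention_heads)  # 16 -> config.n_head
```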

class BloomOnnxConfig(OnnxConfigWithPast):
    torch_onnx_minimum_version = version.parse("1.12")

    def __init__(
        self,
        config: PretrainedConfig,
        task: str = "default",
        patching_specs: List[PatchingSpec] = None,
        use_past: bool = False,
    ):
        super().__init__(config, task=task, patching_specs=patching_specs, use_past=use_past)
        if not getattr(self._config, "pad_token_id", None):
            # TODO: how to do that better?
            self._config.pad_token_id = 0

    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        common_inputs = OrderedDict({"input_ids": {0: "batch", 1: "sequence"}})
        if self.use_past:
            # BLOOM stores values on dynamic axis 2. For more details see: https://github.com/huggingface/transformers/pull/18344
            self.fill_with_past_key_values_(common_inputs, direction="inputs", inverted_values_shape=True)
            common_inputs["attention_mask"] = {0: "batch", 1: "past_sequence + sequence"}
        else:
            common_inputs["attention_mask"] = {0: "batch", 1: "sequence"}

        return common_inputs

    @property
    def num_layers(self) -> int:
        return self._config.n_layer

    @property
    def num_attention_heads(self) -> int:
        return self._config.n_head

    @property
    def atol_for_validation(self) -> float:
        return 1e-3

    def generate_dummy_inputs(
        self,
        tokenizer: "PreTrainedTokenizer",
        batch_size: int = -1,
        seq_length: int = -1,
        is_pair: bool = False,
        framework: Optional["TensorType"] = None,
    ) -> Mapping[str, Any]:
        common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
            tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
        )

        # We need to order the inputs in the way they appear in the forward()
        ordered_inputs = OrderedDict({"input_ids": common_inputs["input_ids"]})

        # Need to add the past_keys
        if self.use_past:
            if not is_torch_available():
                raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
            else:
                import torch

                batch, seqlen = common_inputs["input_ids"].shape
                # Not using the same length for past_key_values
                past_key_values_length = seqlen + 2
                head_dim = self._config.hidden_size // self.num_attention_heads
                past_key_shape = (
                    batch * self.num_attention_heads,
                    head_dim,
                    past_key_values_length,
                )
                past_value_shape = (
                    batch * self.num_attention_heads,
                    past_key_values_length,
                    head_dim,
                )
                ordered_inputs["past_key_values"] = [
                    (torch.zeros(past_key_shape), torch.zeros(past_value_shape)) for _ in range(self.num_layers)
                ]

        ordered_inputs["attention_mask"] = common_inputs["attention_mask"]
        if self.use_past:
            mask_dtype = ordered_inputs["attention_mask"].dtype
            ordered_inputs["attention_mask"] = torch.cat(
                [ordered_inputs["attention_mask"], torch.ones(batch, past_key_values_length, dtype=mask_dtype)], dim=1
            )

        return ordered_inputs

    @property
    def default_onnx_opset(self) -> int:
        return 13
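For reference, the dummy `past_key_values` built above follow BLOOM's fused batch-and-head layout: key states are `(batch * n_head, head_dim, past_len)` and value states are `(batch * n_head, past_len, head_dim)`. A small sketch with made-up sizes:

```python
import torch

batch, n_head, hidden_size, seqlen = 2, 8, 64, 5
head_dim = hidden_size // n_head          # 8
past_len = seqlen + 2                     # mirrors `past_key_values_length = seqlen + 2` above

past_key = torch.zeros(batch * n_head, head_dim, past_len)    # shape (16, 8, 7)
past_value = torch.zeros(batch * n_head, past_len, head_dim)  # shape (16, 7, 8)
```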
@@ -1,255 +0,0 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert BigScience BLOOM checkpoint."""


import argparse
import json
import os
import re

import torch

from transformers import BloomConfig, BloomModel
from transformers.file_utils import CONFIG_NAME, WEIGHTS_NAME
from transformers.utils import logging


logging.set_verbosity_info()

WEIGHTS_TO_AVERAGE_ENDSWITH = [
    "word_embeddings_layernorm.weight",
    "word_embeddings_layernorm.bias",
    "input_layernorm.weight",
    "input_layernorm.bias",
    "post_attention_layernorm.weight",
    "post_attention_layernorm.bias",
    "self_attention.dense.bias",
    "mlp.dense_4h_to_h.bias",
    "ln_f.weight",
    "ln_f.bias",
]

WEIGHTS_WITH_ROW_PARALLELISM_CONTAIN = [
    "mlp.dense_4h_to_h.weight",
    "self_attention.dense.weight",
]


def layer_name_mapping(key, file):
    """Convert a Megatron-DeepSpeed TP/PP weight name to the transformers (PP-only) naming."""
    # Handle first and last layers
    layer_rename_map = {
        "word_embeddings.weight": "word_embeddings.weight",
        "word_embeddings.norm.weight": "word_embeddings_layernorm.weight",
        "word_embeddings.norm.bias": "word_embeddings_layernorm.bias",
        "weight": "ln_f.weight",
        "bias": "ln_f.bias",
    }

    if key in layer_rename_map:
        return layer_rename_map[key]

    # Handle transformer blocks
    layer_number = int(re.match(r".*layer_(\d*).*", file)[1])
    layer_number -= 3
    return f"h.{layer_number}." + key


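A hedged example of the mapping above, assuming Megatron-DeepSpeed shard files are named like `layer_05-model_00-model_states.pt` (the exact naming is an assumption); the `- 3` offset reflects that the leading layer files hold the embedding-side weights rather than transformer blocks:

```python
# Hypothetical shard file name; the real naming scheme may differ.
file = "layer_05-model_00-model_states.pt"

print(layer_name_mapping("self_attention.dense.weight", file))
# -> "h.2.self_attention.dense.weight"   (layer 5 - offset 3 = block index 2)

print(layer_name_mapping("word_embeddings.weight", file))
# -> "word_embeddings.weight"            (embedding/final-norm keys use the rename map instead)
```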
def get_dtype_size(dtype):
    if dtype == torch.bool:
        return 1 / 8
    bit_search = re.search(r"[^\d](\d+)$", str(dtype))
    if bit_search is None:
        raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
    bit_size = int(bit_search.groups()[0])
    return bit_size // 8


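`get_dtype_size` just parses the bit width out of the dtype's string form, e.g.:

```python
import torch

print(get_dtype_size(torch.float16))  # 2 ("torch.float16" -> 16 bits -> 2 bytes)
print(get_dtype_size(torch.float32))  # 4
print(get_dtype_size(torch.bool))     # 0.125 (booleans are counted as 1/8 byte)
```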
def convert_bloom_checkpoint_to_pytorch(
    bloom_checkpoint_path, bloom_config_file, pytorch_dump_folder_path, shard_model, pretraining_tp
):
    # Construct model
    if bloom_config_file == "":
        config = BloomConfig()
    else:
        config = BloomConfig.from_json_file(bloom_config_file)

    if shard_model:
        file_names = os.listdir(bloom_checkpoint_path)
        file_names = sorted(filter(lambda s: s.startswith("layer") and "model_00" in s, file_names))

        index_dict = {"weight_map": {}, "metadata": {}}
        total_size = 0

        missing_keys = None

        config = BloomConfig()

        for j, file in enumerate(file_names):
            print("Processing file: {}".format(file))
            tensors = None

            for i in range(pretraining_tp):
                # load all TP files
                f_name = file.replace("model_00", f"model_0{i}")
                temp = torch.load(os.path.join(bloom_checkpoint_path, f_name), map_location="cpu")

                # Rename keys in the transformers names
                keys = list(temp.keys())
                for key in keys:
                    temp[layer_name_mapping(key, file)] = temp.pop(key)

                if tensors is None:
                    tensors = temp
                else:
                    for key in tensors.keys():
                        if any(key.endswith(end) for end in WEIGHTS_TO_AVERAGE_ENDSWITH):
                            # We average (sum and then divide) some weights across TP ranks (see https://github.com/bigscience-workshop/Megatron-DeepSpeed/blob/olruwase/sync_layer_norms/megatron/training.py#L425)
                            tensors[key] += temp[key]
                        else:
                            # Some weights are RowParallelLinear in Megatron-Deepspeed, others are ColumnParallel
                            cat_dim = 1 if any(text in key for text in WEIGHTS_WITH_ROW_PARALLELISM_CONTAIN) else 0
                            # We concatenate these weights across TP ranks
                            tensors[key] = torch.cat([tensors[key], temp[key]], dim=cat_dim)

            # Divide the weights we want to average by the number of TP ranks
            for key in tensors.keys():
                if any(key.endswith(end) for end in WEIGHTS_TO_AVERAGE_ENDSWITH):
                    tensors[key] = tensors[key] / pretraining_tp
            torch.save(
                tensors,
                os.path.join(
                    pytorch_dump_folder_path,
                    "pytorch_model_{}-of-{}.bin".format(str(j + 1).zfill(5), str(len(file_names)).zfill(5)),
                ),
            )

            for key in tensors.keys():
                value = tensors[key]
                total_size += value.numel() * get_dtype_size(value.dtype)
                if key not in index_dict["weight_map"]:
                    index_dict["weight_map"][key] = "pytorch_model_{}-of-{}.bin".format(
                        str(j + 1).zfill(5), str(len(file_names)).zfill(5)
                    )

        config = BloomConfig()
        pytorch_config_dump_path = pytorch_dump_folder_path + "/" + CONFIG_NAME
        index_dict["metadata"]["total_size"] = total_size
        with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
            f.write(config.to_json_string())
        with open(os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME + ".index.json"), "w", encoding="utf-8") as f:
            json_config = json.dumps(index_dict, indent=2, sort_keys=True) + "\n"
            f.write(json_config)
    else:
        model = BloomModel(config)

        file_names = os.listdir(bloom_checkpoint_path)
        file_names = sorted(filter(lambda s: s.startswith("layer") and "model_00" in s, file_names))

        missing_keys = None
        for i, file in enumerate(file_names):
            tensors = None
            for i in range(pretraining_tp):
                # load all TP files
                f_name = file.replace("model_00", f"model_0{i}")
                temp = torch.load(os.path.join(bloom_checkpoint_path, f_name), map_location="cpu")

                # Rename keys in the transformers names
                keys = list(temp.keys())
                for key in keys:
                    temp[layer_name_mapping(key, file)] = temp.pop(key)

                if tensors is None:
                    tensors = temp
                else:
                    for key in tensors.keys():
                        # We average (sum and then divide) some weights across TP ranks (see https://github.com/bigscience-workshop/Megatron-DeepSpeed/blob/olruwase/sync_layer_norms/megatron/training.py#L425)
                        if any(key.endswith(end) for end in WEIGHTS_TO_AVERAGE_ENDSWITH):
                            tensors[key] += temp[key]
                        else:
                            # Some weights are RowParallelLinear in Megatron-Deepspeed, others are ColumnParallel
                            cat_dim = 1 if any(text in key for text in WEIGHTS_WITH_ROW_PARALLELISM_CONTAIN) else 0
                            # We concatenate these weights across TP ranks
                            tensors[key] = torch.cat([tensors[key], temp[key]], dim=cat_dim)

            # Divide the weights we want to average by the number of TP ranks
            for key in tensors.keys():
                if any(key.endswith(end) for end in WEIGHTS_TO_AVERAGE_ENDSWITH):
                    tensors[key] = tensors[key] / pretraining_tp

            other_keys = model.load_state_dict(tensors, strict=False)
            assert not other_keys.unexpected_keys, f"The keys {other_keys.unexpected_keys} are unexpected"
            if missing_keys is None:
                missing_keys = set(other_keys.missing_keys)
            else:
                missing_keys = missing_keys.intersection(set(other_keys.missing_keys))

        assert not missing_keys, f"The keys {missing_keys} are missing"

        # Save pytorch-model
        os.makedirs(pytorch_dump_folder_path, exist_ok=True)
        pytorch_weights_dump_path = pytorch_dump_folder_path + "/" + WEIGHTS_NAME
        pytorch_config_dump_path = pytorch_dump_folder_path + "/" + CONFIG_NAME
        print(f"Save PyTorch model to {pytorch_weights_dump_path} with dtype {config.torch_dtype}")
        if config.torch_dtype is not None:
            model = model.to(config.torch_dtype)
        torch.save(model.state_dict(), pytorch_weights_dump_path)
        print(f"Save configuration file to {pytorch_config_dump_path}")
        with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
            f.write(config.to_json_string())


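The merge rule applied in both branches above is: weights listed in `WEIGHTS_TO_AVERAGE_ENDSWITH` are averaged across TP ranks, row-parallel matrices are concatenated along dim 1, and everything else along dim 0. A minimal sketch of that rule with two fake shards (shapes are illustrative only; it reuses the module-level lists defined earlier in this file):

```python
import torch

shard0 = {
    "h.0.input_layernorm.weight": torch.ones(8),
    "h.0.mlp.dense_4h_to_h.weight": torch.ones(8, 16),   # row-parallel
    "h.0.mlp.dense_h_to_4h.weight": torch.ones(16, 8),   # column-parallel
}
shard1 = {k: v.clone() for k, v in shard0.items()}

merged = {}
for key in shard0:
    if any(key.endswith(end) for end in WEIGHTS_TO_AVERAGE_ENDSWITH):
        merged[key] = (shard0[key] + shard1[key]) / 2  # average layernorms/biases
    else:
        cat_dim = 1 if any(t in key for t in WEIGHTS_WITH_ROW_PARALLELISM_CONTAIN) else 0
        merged[key] = torch.cat([shard0[key], shard1[key]], dim=cat_dim)

print(merged["h.0.mlp.dense_4h_to_h.weight"].shape)  # torch.Size([8, 32]) -> concatenated on dim 1
print(merged["h.0.mlp.dense_h_to_4h.weight"].shape)  # torch.Size([32, 8]) -> concatenated on dim 0
```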
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--bloom_checkpoint_path",
        default=None,
        type=str,
        required=True,
        help="Path to the Megatron-LM checkpoint.",
    )
    parser.add_argument(
        "--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
    )
    parser.add_argument(
        "--bloom_config_file",
        default="",
        type=str,
        help=(
            "An optional config json file corresponding to the pre-trained model. \n"
            "This specifies the model architecture."
        ),
    )
    parser.add_argument(
        "--shard_model",
        action="store_true",
        help="An optional setting to shard the output model. This enables sharding the converted checkpoint.",
    )
    parser.add_argument(
        "--pretraining_tp",
        default=4,
        type=int,
        help="Pretraining TP rank that has been used when training the model in Megatron-LM.",
    )
    args = parser.parse_args()
    convert_bloom_checkpoint_to_pytorch(
        args.bloom_checkpoint_path,
        args.bloom_config_file,
        args.pytorch_dump_folder_path,
        args.shard_model,
        args.pretraining_tp,
    )
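For reference, the CLI above maps onto the conversion function roughly as follows (paths are placeholders, not real checkpoints):

```python
# Equivalent to running the script with --shard_model --pretraining_tp 4.
convert_bloom_checkpoint_to_pytorch(
    bloom_checkpoint_path="/path/to/megatron_checkpoint",  # placeholder
    bloom_config_file="",                                  # empty string -> default BloomConfig()
    pytorch_dump_folder_path="/path/to/output",            # placeholder
    shard_model=True,
    pretraining_tp=4,
)
```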
@@ -1,69 +0,0 @@
# from modelscope.utils.constant import Tasks
# from modelscope.pipelines import pipeline

# prompt = "输入: 中国的首都在哪里\n输出: "

# # `task` can be 抽取 (extraction) or 分类 (classification); `text` is the text to analyze;
# # `labels` is the list of label types, separated by Chinese commas.
# inputs = {'task': '抽取', 'text': '杭州欢迎你。', 'labels': '地名'}
# # PROMPT_TEMPLATE stays unchanged
# PROMPT_TEMPLATE = '输入: {text}\n{task}: {labels}\n输出: '
# # prompt = PROMPT_TEMPLATE.format(**inputs)
# pipeline_ins = pipeline(task=Tasks.text_generation, model='damo/nlp_seqgpt-560m', model_revision = 'v1.0.1', run_kwargs={'gen_token': '[GEN]'})
# print(pipeline_ins(prompt))

import torch
import json
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

from modelscope.preprocessors import TextGenerationTransformersPreprocessor

from modelscope.utils.config import Config
from modelscope import snapshot_download
from transformers import AutoConfig

from modeling_bloom import BloomForCausalLM
from tokenization_bloom_fast import BloomTokenizerFast

seed = 4321
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

model_dir = snapshot_download("damo/nlp_seqgpt-560m")

config, kwargs = AutoConfig.from_pretrained(
    model_dir,
    return_unused_kwargs=True,
    trust_remote_code=True,
    code_revision=None,
    _commit_hash=None,
)

tokenizer_config_file = "./tokenizer_config.json"
if tokenizer_config_file is not None:
    with open(tokenizer_config_file, encoding="utf-8") as tokenizer_config_handle:
        init_kwargs = json.load(tokenizer_config_handle)
    init_kwargs.pop("tokenizer_class", None)
    init_kwargs.pop("tokenizer_file", None)
    saved_init_inputs = init_kwargs.pop("init_inputs", ())
    init_inputs = saved_init_inputs
init_kwargs["vocab_file"] = None
init_kwargs["added_tokens_file"] = None
init_kwargs["special_tokens_map_file"] = "./special_tokens_map.json"
init_kwargs["tokenizer_file"] = "./tokenizer.json"
init_kwargs["name_or_path"] = model_dir
tokenizer = BloomTokenizerFast(*init_inputs, **init_kwargs)

model = BloomForCausalLM(config)
# from_pretrained is a classmethod: it returns a new model loaded from model_dir and
# discards the randomly initialized instance constructed above.
model = model.from_pretrained(model_dir).cuda().train()

# The first prompt ("Where is the capital of China?") is immediately overridden by the second
# ("Where is the capital of the USA?"); only the second is used below.
prompt = "输入: 中国的首都在哪里\n输出: "
prompt = "输入: 美国的首都在哪里\n输出: "


input_ids = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=1024)
input_ids = input_ids.input_ids.cuda()
outputs = model.generate(input_ids, num_beams=4, do_sample=False, max_new_tokens=256)
decoded_sentences = tokenizer.batch_decode(outputs, skip_special_tokens=True)
print(decoded_sentences[0])
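The commented-out block at the top of this script builds its prompt from a template. For the extraction example given there, the formatting works out as follows (this only restates the commented code, it is not an additional API):

```python
PROMPT_TEMPLATE = "输入: {text}\n{task}: {labels}\n输出: "
inputs = {"task": "抽取", "text": "杭州欢迎你。", "labels": "地名"}  # task=extraction, text to analyze, label list

prompt = PROMPT_TEMPLATE.format(**inputs)
print(prompt)
# 输入: 杭州欢迎你。
# 抽取: 地名
# 输出: 
```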
@@ -1,7 +0,0 @@
{
  "_from_model_config": true,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "pad_token_id": 3,
  "transformers_version": "4.28.1"
}
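This `generation_config.json` is what `transformers` picks up as the model's default generation settings. A minimal sketch of reading it back, assuming the file sits in a local model directory (the path is a placeholder):

```python
from transformers import GenerationConfig

gen_config = GenerationConfig.from_pretrained("/path/to/nlp_seqgpt-560m")  # placeholder directory
print(gen_config.bos_token_id, gen_config.eos_token_id, gen_config.pad_token_id)  # 1 2 3
```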
										
											
File diff suppressed because it is too large
@@ -1,734 +0,0 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc. Team and Bigscience Workshop. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Flax BLOOM model."""

import math
from functools import partial
from typing import Optional, Tuple

import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen import combine_masks, dot_product_attention_weights, make_causal_mask
from flax.linen.activation import tanh
from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax

from ...modeling_flax_outputs import (
    FlaxBaseModelOutput,
    FlaxBaseModelOutputWithPastAndCrossAttentions,
    FlaxCausalLMOutput,
)
from ...modeling_flax_utils import FlaxPreTrainedModel, append_call_sample_docstring
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging
from .configuration_bloom import BloomConfig


logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "bigscience/bloom"
_CONFIG_FOR_DOC = "BloomConfig"


BLOOM_START_DOCSTRING = r"""

    This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a Flax Linen
    [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a
    regular Flax Module and refer to the Flax documentation for all matters related to general usage and behavior.

    Finally, this model supports inherent JAX features such as:

    - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
    - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
    - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
    - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)

    Parameters:
        config ([`BloomConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
        dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
            The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
            `jax.numpy.bfloat16` (on TPUs).

            This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
            specified, all the computation will be performed with the given `dtype`.

            **Note that this only specifies the dtype of the computation and does not influence the dtype of model
            parameters.**

            If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
            [`~FlaxPreTrainedModel.to_bf16`].
"""

BLOOM_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`numpy.ndarray` of shape `(batch_size, input_ids_length)`):
            `input_ids_length` = `sequence_length`. Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`BloomTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        past_key_values (`Dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`):
            Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast
            auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


def build_alibi_tensor(attention_mask: jnp.ndarray, num_heads: int, dtype: Optional[jnp.dtype] = jnp.float32):
    """
    Flax implementation of the BLOOM Alibi tensor. BLOOM Alibi tensor is not causal as the original paper mentions; it
    relies on a translation invariance of softmax for quick implementation: with l being a tensor, and a fixed value
    `softmax(l+a) = softmax(l)`. Based on
    https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742
    Link to paper: https://arxiv.org/abs/2108.12409

    Args:
        attention_mask (`jnp.ndarray`):
            Token-wise attention mask, this should be of shape `(batch_size, max_seq_len)`.
        num_heads (`int`):
            Number of attention heads.
        dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`):
            The data type (dtype) of the output tensor.

    Returns: Alibi tensor of shape `(batch_size * num_heads, 1, max_seq_len)`.
    """
    batch_size, seq_length = attention_mask.shape
    closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
    base = jnp.array(2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), dtype=jnp.float32)
    powers = jnp.arange(1, 1 + closest_power_of_2, dtype=jnp.float32)
    slopes = jax.lax.pow(base, powers)

    if closest_power_of_2 != num_heads:
        extra_base = jnp.array(2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), dtype=jnp.float32)
        num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
        extra_powers = jnp.arange(1, 1 + 2 * num_remaining_heads, 2, dtype=jnp.float32)
        slopes = jnp.concatenate([slopes, jax.lax.pow(extra_base, extra_powers)], axis=0)

    # Note: the Alibi tensor will be added to the attention bias that will be applied to the query, key product of attention
    # therefore, Alibi will have to be of shape (batch_size, num_heads, query_length, key_length)
    # => here we set (batch_size=1, num_heads=num_heads, query_length=1, key_length=max_length)
    # so that the query_length dimension will then be broadcast correctly.
    # This is more or less identical to T5's relative position bias:
    # https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527
    arange_tensor = ((attention_mask.cumsum(axis=-1) - 1) * attention_mask)[:, None, :]
    alibi = slopes[..., None] * arange_tensor
    alibi = jnp.expand_dims(alibi, axis=2)
    return jnp.asarray(alibi, dtype)


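The slope construction above produces a geometric sequence of per-head slopes; for example, with `num_heads = 8` (a power of two, so the `extra_base` branch is skipped):

```python
import math

num_heads = 8
closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))   # 8
base = 2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3)))   # 2**-1 = 0.5
slopes = [base ** p for p in range(1, 1 + closest_power_of_2)]
print(slopes)  # [0.5, 0.25, 0.125, 0.0625, 0.03125, 0.015625, 0.0078125, 0.00390625]
```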
class FlaxBloomAttention(nn.Module):
    config: BloomConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.hidden_size = self.config.hidden_size
        self.num_heads = self.config.n_head
        self.head_dim = self.hidden_size // self.num_heads
        self.attention_softmax_in_fp32 = self.dtype is not jnp.float32

        if self.head_dim * self.num_heads != self.hidden_size:
            raise ValueError(
                f"`hidden_size` must be divisible by `num_heads` (got `hidden_size`: {self.hidden_size} and "
                f"`num_heads`: {self.num_heads})."
            )

        dense = partial(
            nn.Dense,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
        )

        self.query_key_value = dense(self.hidden_size * 3)
        self.dense = dense(self.hidden_size)
        self.resid_dropout = nn.Dropout(rate=self.config.hidden_dropout)

    def _split_heads(self, hidden_states):
        return hidden_states.reshape(hidden_states.shape[:-1] + (self.num_heads, self.head_dim * 3))

    def _merge_heads(self, hidden_states):
        return hidden_states.reshape(hidden_states.shape[:2] + (self.hidden_size,))

    @nn.compact
    # Copied from transformers.models.gptj.modeling_flax_gptj.FlaxGPTJAttention._concatenate_to_cache
    def _concatenate_to_cache(self, key, value, query, attention_mask):
        """
        This function takes projected key, value states from a single input token and concatenates the states to cached
        states from previous steps. This function is slightly adapted from the official Flax repository:
        https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
        """
        # detect if we're initializing by absence of existing cache data.
        is_initialized = self.has_variable("cache", "cached_key")
        cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
        cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
        cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))

        if is_initialized:
            *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
            # update key, value caches with our new 1d spatial slices
            cur_index = cache_index.value
            indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
            key = lax.dynamic_update_slice(cached_key.value, key, indices)
            value = lax.dynamic_update_slice(cached_value.value, value, indices)
            cached_key.value = key
            cached_value.value = value
            num_updated_cache_vectors = query.shape[1]
            cache_index.value = cache_index.value + num_updated_cache_vectors
            # causal mask for cached decoder self-attention: our single query position should only attend to those key
            # positions that have already been generated and cached, not the remaining zero elements.
            pad_mask = jnp.broadcast_to(
                jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
                tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
            )
            attention_mask = combine_masks(pad_mask, attention_mask)
        return key, value, attention_mask

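To make the fused projection concrete: `query_key_value` emits `3 * hidden_size` features, `_split_heads` regroups them per head, and the `jnp.split` in `__call__` below separates query, key and value. A small shape walk-through with assumed sizes:

```python
import jax.numpy as jnp

batch, seq, num_heads, head_dim = 2, 5, 8, 8
hidden_size = num_heads * head_dim                                  # 64

fused_qkv = jnp.zeros((batch, seq, 3 * hidden_size))                # output of query_key_value
fused_qkv = fused_qkv.reshape(batch, seq, num_heads, 3 * head_dim)  # what _split_heads does
query, key, value = jnp.split(fused_qkv, 3, axis=-1)
print(query.shape)  # (2, 5, 8, 8)
```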
    def __call__(
        self,
        hidden_states,
        residual,
        alibi,
        attention_mask=None,
        deterministic: bool = True,
        init_cache: bool = False,
        output_attentions: bool = False,
    ):
        batch_size, seq_length = hidden_states.shape[:2]

        # proj q, k, v
        fused_qkv = self.query_key_value(hidden_states)
        fused_qkv = self._split_heads(fused_qkv)
        query, key, value = jnp.split(fused_qkv, 3, axis=-1)

        causal_attention_mask = make_causal_mask(attention_mask, dtype="bool")

        # for fast decoding causal attention mask should be shifted
        causal_attention_mask_shift = (
            self.variables["cache"]["cache_index"] if self.has_variable("cache", "cached_key") else 0
        )

        # fast decoding for generate requires special attention_mask
        if self.has_variable("cache", "cached_key"):
            max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
            causal_attention_mask = jax.lax.dynamic_slice(
                causal_attention_mask,
                (0, 0, causal_attention_mask_shift, 0),
                (1, 1, seq_length, max_decoder_length),
            )

        # broadcast causal attention mask & attention mask to fit for merge
        causal_attention_mask = jnp.broadcast_to(
            causal_attention_mask, (batch_size,) + causal_attention_mask.shape[1:]
        )
        attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_attention_mask.shape)
        attention_mask = combine_masks(attention_mask, causal_attention_mask)

        dropout_rng = None
        if not deterministic and self.config.attention_dropout > 0.0:
            dropout_rng = self.make_rng("dropout")

        # During fast autoregressive decoding, we feed one position at a time,
        # and cache the keys and values step by step.
        if self.has_variable("cache", "cached_key") or init_cache:
            key, value, attention_mask = self._concatenate_to_cache(key, value, query, attention_mask)

        # transform boolean mask into float mask
        mask_value = jnp.finfo(self.dtype).min
        attention_bias = lax.select(
            attention_mask > 0,
            jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
            jnp.full(attention_mask.shape, mask_value).astype(self.dtype),
        )

        attention_bias = attention_bias + alibi

        # Cast in fp32 if the original dtype is different from fp32
        attention_dtype = jnp.float32 if self.attention_softmax_in_fp32 else self.dtype

        attn_weights = dot_product_attention_weights(
            query,
            key,
            bias=attention_bias,
            dropout_rng=dropout_rng,
            dropout_rate=self.config.attention_dropout,
            deterministic=deterministic,
            dtype=attention_dtype,
        )

        # Cast back in the original dtype if the native dtype is not fp32
        if self.attention_softmax_in_fp32:
            attn_weights = attn_weights.astype(self.dtype)

        attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value)
        attn_output = self._merge_heads(attn_output)
        attn_output = self.dense(attn_output)
        attn_output = self.resid_dropout(attn_output, deterministic=deterministic)

        attn_output = attn_output + residual

        outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
        return outputs


class BloomGELU(nn.Module):
    def setup(self):
        self.dtype = jnp.float32

    def __call__(self, x):
        return x * 0.5 * (1.0 + tanh(0.79788456 * x * (1 + 0.044715 * x * x)))


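`BloomGELU` is the standard tanh approximation of GELU (0.79788456 ≈ sqrt(2/π)); numerically it matches `jax.nn.gelu` with `approximate=True`:

```python
import jax
import jax.numpy as jnp

x = jnp.linspace(-3.0, 3.0, 7)
bloom_gelu = x * 0.5 * (1.0 + jnp.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
print(jnp.allclose(bloom_gelu, jax.nn.gelu(x, approximate=True), atol=1e-6))  # True
```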
class FlaxBloomMLP(nn.Module):
    config: BloomConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        hidden_size = self.config.hidden_size

        kernel_init = jax.nn.initializers.normal(self.config.initializer_range)

        self.dense_h_to_4h = nn.Dense(4 * hidden_size, dtype=self.dtype, kernel_init=kernel_init)
        self.dense_4h_to_h = nn.Dense(hidden_size, dtype=self.dtype, kernel_init=kernel_init)
        self.hidden_dropout = nn.Dropout(self.config.hidden_dropout)
        self.act = BloomGELU()

    def __call__(self, hidden_states, residual, deterministic: bool = True):
        hidden_states = self.dense_h_to_4h(hidden_states)
        hidden_states = self.act(hidden_states)

        intermediate_output = self.dense_4h_to_h(hidden_states)

        intermediate_output = intermediate_output + residual
        hidden_states = self.hidden_dropout(intermediate_output, deterministic=deterministic)

        return hidden_states


class FlaxBloomBlock(nn.Module):
    config: BloomConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.input_layernorm = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype)

        self.self_attention = FlaxBloomAttention(self.config, dtype=self.dtype)
        self.post_attention_layernorm = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype)

        self.mlp = FlaxBloomMLP(self.config, dtype=self.dtype)

        self.apply_residual_connection_post_layernorm = self.config.apply_residual_connection_post_layernorm
        self.hidden_dropout = self.config.hidden_dropout

    def __call__(
        self,
        hidden_states,
        alibi,
        attention_mask=None,
        deterministic: bool = True,
        init_cache: bool = False,
        output_attentions: bool = False,
    ):
        layernorm_output = self.input_layernorm(hidden_states)

        # layer norm before saving residual if config calls for it
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states

        # self-attention
        attn_outputs = self.self_attention(
            layernorm_output,
            residual=residual,
            alibi=alibi,
            attention_mask=attention_mask,
            deterministic=deterministic,
            init_cache=init_cache,
            output_attentions=output_attentions,
        )

        attention_output = attn_outputs[0]

        outputs = attn_outputs[1:]

        post_layernorm = self.post_attention_layernorm(attention_output)

        # set residual based on config
        if self.apply_residual_connection_post_layernorm:
            residual = post_layernorm
        else:
            residual = attention_output

        output = self.mlp(post_layernorm, residual, deterministic=deterministic)

        outputs = (output,) + outputs

        return outputs

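The two `apply_residual_connection_post_layernorm` branches are easier to see as plain dataflow. Below is a schematic of the per-block wiring, not part of the deleted file; in the real modules the residual addition happens inside `FlaxBloomAttention` and `FlaxBloomMLP`, which receive `residual` as an argument:

```python
def bloom_block(x, ln1, attn, ln2, mlp, post_ln_residual: bool):
    """Schematic of FlaxBloomBlock with explicit residual adds."""
    h = ln1(x)
    # residual is either the normalized activations or the raw block input
    residual = h if post_ln_residual else x
    attn_out = attn(h) + residual

    h2 = ln2(attn_out)
    residual = h2 if post_ln_residual else attn_out
    return mlp(h2) + residual

# usage sketch with identity stand-ins for the sub-layers
identity = lambda t: t
print(bloom_block(1.0, identity, identity, identity, identity, post_ln_residual=False))  # 4.0
```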
class FlaxBloomPreTrainedModel(FlaxPreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = BloomConfig
    base_model_prefix = "transformer"
    module_class: nn.Module = None

    def __init__(
        self,
        config: BloomConfig,
        input_shape: Tuple = (1, 1),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
        **kwargs,
    ):
        module = self.module_class(config=config, dtype=dtype, **kwargs)
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
        # init input tensors
        input_ids = jnp.zeros(input_shape, dtype="i4")
        attention_mask = jnp.ones_like(input_ids)
        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        random_params = self.module.init(rngs, input_ids, attention_mask, return_dict=False)["params"]

        if params is not None:
            random_params = flatten_dict(unfreeze(random_params))
            params = flatten_dict(unfreeze(params))
            for missing_key in self._missing_keys:
                params[missing_key] = random_params[missing_key]
            self._missing_keys = set()
            return freeze(unflatten_dict(params))
        else:
            return random_params

    def init_cache(self, batch_size, max_length):
        r"""
        Args:
            batch_size (`int`):
                batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
            max_length (`int`):
                maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
                cache.
        """
        # init input variables to retrieve cache
        input_ids = jnp.ones((batch_size, max_length), dtype="i4")
        attention_mask = jnp.ones_like(input_ids)

        init_variables = self.module.init(
            jax.random.PRNGKey(0), input_ids, attention_mask, return_dict=False, init_cache=True
        )
        return unfreeze(init_variables["cache"])

    @add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING)
    def __call__(
        self,
        input_ids,
        attention_mask=None,
        past_key_values: dict = None,
        params: dict = None,
        dropout_rng: jax.random.PRNGKey = None,
        train: bool = False,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        batch_size, sequence_length = input_ids.shape

        if attention_mask is None:
            attention_mask = jnp.ones((batch_size, sequence_length))

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        inputs = {"params": params or self.params}

        # If past_key_values are passed, the cache is already initialized and a private init_cache flag has to be
        # passed down to ensure the cache is used. The cache also has to be marked as mutable so that it can be
        # changed by the FlaxBloomAttention module.
        if past_key_values:
            inputs["cache"] = past_key_values
            mutable = ["cache"]
        else:
            mutable = False

        outputs = self.module.apply(
            inputs,
            jnp.array(input_ids, dtype="i4"),
            jnp.array(attention_mask, dtype="i4"),
            not train,
            False,
            output_attentions,
            output_hidden_states,
            return_dict,
            rngs=rngs,
            mutable=mutable,
        )

        # add updated cache to model output
        if past_key_values is not None and return_dict:
            outputs, past_key_values = outputs
            outputs["past_key_values"] = unfreeze(past_key_values["cache"])
            return outputs
        elif past_key_values is not None and not return_dict:
            outputs, past_key_values = outputs
            outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:]

        return outputs

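The `mutable=["cache"]` handling in `__call__` is generic Flax machinery: a variable collection created during `init_cache` is passed back into `module.apply`, marked mutable, and the updated collection comes back alongside the outputs. A tiny self-contained illustration of that mechanism (a toy counter, not the Bloom decoding cache itself), assuming `flax` and `jax` are installed:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class Counter(nn.Module):
    @nn.compact
    def __call__(self, x):
        # a mutable "cache" collection, analogous to the decoding cache above
        count = self.variable("cache", "count", lambda: jnp.zeros((), jnp.int32))
        count.value = count.value + 1
        return x

mod = Counter()
variables = mod.init(jax.random.PRNGKey(0), jnp.ones(3))        # cache/count == 1 after init
y, mutated = mod.apply(variables, jnp.ones(3), mutable=["cache"])
print(mutated["cache"]["count"])  # 2: the updated collection is returned separately from the output
```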
class FlaxBloomBlockCollection(nn.Module):
    config: BloomConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.layers = [
            FlaxBloomBlock(self.config, name=str(layer_number), dtype=self.dtype)
            for layer_number in range(self.config.num_hidden_layers)
        ]

    def __call__(
        self,
        hidden_states,
        alibi,
        attention_mask=None,
        deterministic: bool = True,
        init_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
    ):
        all_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None

        for layer_number in range(self.config.num_hidden_layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            layer_outputs = self.layers[layer_number](
                hidden_states,
                alibi=alibi,
                attention_mask=attention_mask,
                deterministic=deterministic,
                init_cache=init_cache,
                output_attentions=output_attentions,
            )
            hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions += (layer_outputs[1],)

        # this contains possible `None` values - `FlaxBloomModule` will filter them out
        outputs = (hidden_states, all_hidden_states, all_attentions)

        return outputs

class FlaxBloomModule(nn.Module):
    config: BloomConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.embed_dim = self.config.hidden_size

        # word embeddings (no positional embedding layer)
        self.word_embeddings = nn.Embed(
            self.config.vocab_size,
            self.embed_dim,
            embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
            dtype=self.dtype,
        )

        # post-embedding layernorm
        self.word_embeddings_layernorm = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype)

        # transformer layers
        self.h = FlaxBloomBlockCollection(self.config, dtype=self.dtype)

        # final layernorm
        self.ln_f = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype)

    def __call__(
        self,
        input_ids=None,
        attention_mask=None,
        deterministic=True,
        init_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        inputs_embeds = self.word_embeddings(input_ids)
        # do post-embedding layernorm
        hidden_states = self.word_embeddings_layernorm(inputs_embeds)

        # build alibi depending on `attention_mask`
        alibi = build_alibi_tensor(attention_mask, self.config.n_head, dtype=hidden_states.dtype)

        outputs = self.h(
            hidden_states,
            alibi=alibi,
            attention_mask=attention_mask,
            deterministic=deterministic,
            init_cache=init_cache,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
        )

        hidden_states = outputs[0]
        hidden_states = self.ln_f(hidden_states)

        if output_hidden_states:
            all_hidden_states = outputs[1] + (hidden_states,)
            outputs = (hidden_states, all_hidden_states) + outputs[2:]
        else:
            outputs = (hidden_states,) + outputs[1:]

        if not return_dict:
            return tuple(v for v in [outputs[0], outputs[-1]] if v is not None)

        return FlaxBaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            hidden_states=outputs[1],
            attentions=outputs[-1],
        )

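Since BLOOM has no positional embedding layer, all position information enters through the `alibi` bias built from the attention mask and the number of heads. `build_alibi_tensor` is referenced above but defined in a part of the file not shown in this diff view; the sketch below is a simplified, hedged reconstruction of the idea (per-head geometric slopes from the ALiBi paper, assuming a power-of-two head count; the real helper also handles other head counts and dtype casting):

```python
import jax.numpy as jnp

def alibi_slopes(num_heads: int) -> jnp.ndarray:
    # geometric sequence starting at 2**(-8/num_heads), as described in the ALiBi paper
    start = 2.0 ** (-8.0 / num_heads)
    return jnp.array([start ** (i + 1) for i in range(num_heads)])

def simple_alibi(attention_mask: jnp.ndarray, num_heads: int) -> jnp.ndarray:
    # position index of each non-padding token (padding positions stay 0)
    positions = (attention_mask.cumsum(-1) - 1) * attention_mask      # [batch, seq]
    slopes = alibi_slopes(num_heads)                                  # [heads]
    return slopes[None, :, None] * positions[:, None, :]              # [batch, heads, key_len]

mask = jnp.array([[1, 1, 1, 1]])
print(simple_alibi(mask, num_heads=2).shape)  # (1, 2, 4)
```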
@add_start_docstrings(
    "The bare Bloom Model transformer outputting raw hidden-states without any specific head on top.",
    BLOOM_START_DOCSTRING,
)
# Copied from transformers.models.gpt_neo.modeling_flax_gpt_neo.FlaxGPTNeoModel with GPTNeo->Bloom
class FlaxBloomModel(FlaxBloomPreTrainedModel):
    module_class = FlaxBloomModule


append_call_sample_docstring(FlaxBloomModel, _CHECKPOINT_FOR_DOC, FlaxBaseModelOutput, _CONFIG_FOR_DOC)

class FlaxBloomForCausalLMModule(nn.Module):
    config: BloomConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.transformer = FlaxBloomModule(self.config, dtype=self.dtype)
        self.lm_head = nn.Dense(
            self.config.vocab_size,
            use_bias=False,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
        )

    def __call__(
        self,
        input_ids,
        attention_mask,
        deterministic: bool = True,
        init_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            deterministic=deterministic,
            init_cache=init_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]

        if self.config.tie_word_embeddings:
            shared_kernel = self.transformer.variables["params"]["word_embeddings"]["embedding"].T
            lm_logits = self.lm_head.apply({"params": {"kernel": shared_kernel}}, hidden_states)
        else:
            lm_logits = self.lm_head(hidden_states)

        if not return_dict:
            return (lm_logits,) + outputs[1:]

        return FlaxCausalLMOutput(logits=lm_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)

@add_start_docstrings(
    """
    The Bloom Model transformer with a language modeling head on top (linear layer with weights tied to the input
    embeddings).
    """,
    BLOOM_START_DOCSTRING,
)
class FlaxBloomForCausalLM(FlaxBloomPreTrainedModel):
    module_class = FlaxBloomForCausalLMModule

    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
        # initializing the cache
        batch_size, seq_length = input_ids.shape

        past_key_values = self.init_cache(batch_size, max_length)
        # Note that usually one would have to put 0's in the attention_mask for
        # x > input_ids.shape[-1] and x < cache_length. But since Bloom uses a causal mask,
        # those positions are masked anyway. Thus, we can create a single static attention_mask here,
        # which is more efficient for compilation
        extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
        if attention_mask is not None:
            extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0))

        return {
            "past_key_values": past_key_values,
            "attention_mask": extended_attention_mask,
        }

    def update_inputs_for_generation(self, model_outputs, model_kwargs):
        model_kwargs["past_key_values"] = model_outputs.past_key_values
        return model_kwargs


append_call_sample_docstring(FlaxBloomForCausalLM, _CHECKPOINT_FOR_DOC, FlaxCausalLMOutput, _CONFIG_FOR_DOC)

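For context, the causal-LM wrapper above plugs into the standard Flax `generate` loop through `prepare_inputs_for_generation` and `update_inputs_for_generation`. A hedged usage sketch, assuming the classes by these names are importable from `transformers` and using `bigscience/bloom-560m` purely as an example checkpoint:

```python
from transformers import AutoTokenizer, FlaxBloomForCausalLM

tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
model = FlaxBloomForCausalLM.from_pretrained("bigscience/bloom-560m")

inputs = tokenizer("Hello, my name is", return_tensors="np")
# greedy decoding; the cache is initialised and updated through the two hooks above
out = model.generate(inputs["input_ids"], max_length=20, do_sample=False)
print(tokenizer.decode(out.sequences[0], skip_special_tokens=True))
```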
@@ -1,9 +0,0 @@
{
  "additional_special_tokens": [
    "[GEN]"
  ],
  "bos_token": "<s>",
  "eos_token": "</s>",
  "pad_token": "<pad>",
  "unk_token": "<unk>"
}
@@ -1,177 +0,0 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for Bloom."""


import pickle
from typing import Optional, Tuple

from transformers.tokenization_utils_base import BatchEncoding
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
from transformers.utils import logging


logger = logging.get_logger(__name__)

VOCAB_FILES_NAMES = {"tokenizer_file": "tokenizer.json"}

PRETRAINED_VOCAB_FILES_MAP = {
    "tokenizer_file": {
        "bigscience/tokenizer": "https://huggingface.co/bigscience/tokenizer/blob/main/tokenizer.json",
        "bigscience/bloom-560m": "https://huggingface.co/bigscience/bloom-560m/blob/main/tokenizer.json",
        "bigscience/bloom-1b1": "https://huggingface.co/bigscience/bloom-1b1/blob/main/tokenizer.json",
        "bigscience/bloom-1b7": "https://huggingface.co/bigscience/bloom-1b7/blob/main/tokenizer.json",
        "bigscience/bloom-3b": "https://huggingface.co/bigscience/bloom-3b/blob/main/tokenizer.json",
        "bigscience/bloom-7b1": "https://huggingface.co/bigscience/bloom-7b1/blob/main/tokenizer.json",
        "bigscience/bloom": "https://huggingface.co/bigscience/bloom/blob/main/tokenizer.json",
    },
}


class BloomTokenizerFast(PreTrainedTokenizerFast):
    """
    Construct a "fast" Bloom tokenizer (backed by HuggingFace's *tokenizers* library). Based on byte-level
    Byte-Pair-Encoding.

    This tokenizer has been trained to treat spaces like parts of the tokens (a bit like sentencepiece), so a word
    will be encoded differently depending on whether it is at the beginning of the sentence (without a space) or not:

    ```python
    >>> from transformers import BloomTokenizerFast

    >>> tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom")
    >>> tokenizer("Hello world")["input_ids"]
    [59414, 8876]

    >>> tokenizer(" Hello world")["input_ids"]
    [86153, 8876]
    ```

    You can get around that behavior by passing `add_prefix_space=True` when instantiating this tokenizer, but since
    the model was not pretrained this way, it might yield a decrease in performance.

    <Tip>

    When used with `is_split_into_words=True`, this tokenizer needs to be instantiated with `add_prefix_space=True`.

    </Tip>

    This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
    refer to this superclass for more information regarding those methods.

    Args:
        vocab_file (`str`):
            Path to the vocabulary file.
        merges_file (`str`):
            Path to the merges file.
        errors (`str`, *optional*, defaults to `"replace"`):
            Paradigm to follow when decoding bytes to UTF-8. See
            [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information.
        unk_token (`str`, *optional*, defaults to `<|endoftext|>`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be
            this token instead.
        bos_token (`str`, *optional*, defaults to `<|endoftext|>`):
            The beginning of sequence token.
        eos_token (`str`, *optional*, defaults to `<|endoftext|>`):
            The end of sequence token.
        add_prefix_space (`bool`, *optional*, defaults to `False`):
            Whether or not to add an initial space to the input. This allows the leading word to be treated like any
            other word. (The Bloom tokenizer detects the beginning of words by the preceding space.)
        trim_offsets (`bool`, *optional*, defaults to `True`):
            Whether or not the post-processing step should trim offsets to avoid including whitespaces.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    model_input_names = ["input_ids", "attention_mask"]
    slow_tokenizer_class = None
    # No `max_model_input_sizes` as BLOOM uses ALiBi positional embeddings

    def __init__(
        self,
        vocab_file=None,
        merges_file=None,
        tokenizer_file=None,
        unk_token="<unk>",
        bos_token="<s>",
        eos_token="</s>",
        pad_token="<pad>",
        add_prefix_space=False,
        clean_up_tokenization_spaces=False,
        **kwargs,
    ):
        super().__init__(
            vocab_file,
            merges_file,
            tokenizer_file=tokenizer_file,
            unk_token=unk_token,
            bos_token=bos_token,
            eos_token=eos_token,
            pad_token=pad_token,
            add_prefix_space=add_prefix_space,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            **kwargs,
        )
        # TODO @ArthurZucker this can only work one way for now, to update later-on. Tests should also properly
        # check this as they were green before.
        pre_tok_state = pickle.dumps(self.backend_tokenizer.pre_tokenizer)
        decoder_state = pickle.dumps(self.backend_tokenizer.decoder)

        if add_prefix_space:
            pre_tok_state = pre_tok_state.replace(b'"add_prefix_space":false', b'"add_prefix_space": true')
            decoder_state = decoder_state.replace(b'"add_prefix_space":false', b'"add_prefix_space": true')
        self.backend_tokenizer.pre_tokenizer = pickle.loads(pre_tok_state)
        self.backend_tokenizer.decoder = pickle.loads(decoder_state)

        self.add_prefix_space = add_prefix_space

    def _batch_encode_plus(self, *args, **kwargs) -> BatchEncoding:
        is_split_into_words = kwargs.get("is_split_into_words", False)
        if not (self.add_prefix_space or not is_split_into_words):
            raise Exception(
                f"You need to instantiate {self.__class__.__name__} with add_prefix_space=True to use it with"
                " pretokenized inputs."
            )

        return super()._batch_encode_plus(*args, **kwargs)

    def _encode_plus(self, *args, **kwargs) -> BatchEncoding:
        is_split_into_words = kwargs.get("is_split_into_words", False)

        if not (self.add_prefix_space or not is_split_into_words):
            raise Exception(
                f"You need to instantiate {self.__class__.__name__} with add_prefix_space=True to use it with"
                " pretokenized inputs."
            )

        return super()._encode_plus(*args, **kwargs)

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        files = self._tokenizer.model.save(save_directory, name=filename_prefix)
        return tuple(files)

    @property
    # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.default_chat_template
    def default_chat_template(self):
        """
        A simple chat template that ignores role information and just concatenates messages with EOS tokens.
        """
        logger.warning_once(
            "\nNo chat template is defined for this tokenizer - using the default template "
            f"for the {self.__class__.__name__} class. If the default is not appropriate for "
            "your model, please set `tokenizer.chat_template` to an appropriate template. "
            "See https://huggingface.co/docs/transformers/main/chat_templating for more information.\n"
        )
        return "{% for message in messages %}" "{{ message.content }}{{ eos_token }}" "{% endfor %}"
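The `_encode_plus`/`_batch_encode_plus` overrides above enforce the constraint stated in the class docstring: pretokenized input (`is_split_into_words=True`) only works when the tokenizer was built with `add_prefix_space=True`. A hedged sketch of both sides of that check, using `bigscience/bloom-560m` only as an example checkpoint:

```python
from transformers import BloomTokenizerFast

# default tokenizer: pretokenized input raises
tok = BloomTokenizerFast.from_pretrained("bigscience/bloom-560m")
try:
    tok(["Hello", "world"], is_split_into_words=True)
except Exception as err:
    print(err)  # "... instantiate ... with add_prefix_space=True to use it with pretokenized inputs."

# with add_prefix_space=True the same call is allowed
tok_ws = BloomTokenizerFast.from_pretrained("bigscience/bloom-560m", add_prefix_space=True)
print(tok_ws(["Hello", "world"], is_split_into_words=True)["input_ids"])
```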
							
								
								
									
seqgpt/tokenizer.json  501215
File diff suppressed because it is too large

@@ -1,11 +0,0 @@
{
  "add_prefix_space": false,
  "bos_token": "<s>",
  "clean_up_tokenization_spaces": false,
  "eos_token": "</s>",
  "model_max_length": 1000000000000000019884624838656,
  "pad_token": "<pad>",
  "padding_side": "left",
  "tokenizer_class": "BloomTokenizer",
  "unk_token": "<unk>"
}
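Two details of this config are worth calling out: `padding_side` is `"left"`, which is what a decoder-only model with a decoding cache needs (the newest tokens must sit at the end of the window), and `model_max_length` is the usual "very large integer" sentinel, i.e. effectively unbounded since ALiBi does not fix a maximum length. The special-tokens file earlier in this diff also registers `"[GEN]"` as an additional special token. A hedged sketch of the effect, assuming the tokenizer files live in a local `seqgpt/` directory (the exact path is not shown in this diff):

```python
from transformers import AutoTokenizer

# hypothetical local path containing tokenizer.json, tokenizer_config.json and the special tokens map
tok = AutoTokenizer.from_pretrained("./seqgpt")  # path assumed for illustration

print(tok.additional_special_tokens)  # expected to include "[GEN]"

batch = tok(["short", "a somewhat longer input"], padding=True, return_tensors="np")
# with padding_side="left", pad tokens are prepended, so each attention_mask row ends in 1s
print(batch["attention_mask"])
```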
										
Binary file not shown.