Delete unused files.
This commit is contained in:
parent
a15e55bead
commit
89c12380cb
File diff suppressed because it is too large
Load Diff
|
@ -1,30 +0,0 @@
|
||||||
{
|
|
||||||
"architectures": [
|
|
||||||
"T5ForConditionalGeneration"
|
|
||||||
],
|
|
||||||
"d_ff": 2048,
|
|
||||||
"d_kv": 64,
|
|
||||||
"d_model": 768,
|
|
||||||
"decoder_start_token_id": 0,
|
|
||||||
"dense_act_fn": "gelu_new",
|
|
||||||
"dropout_rate": 0.1,
|
|
||||||
"eos_token_id": 1,
|
|
||||||
"feed_forward_proj": "gated-gelu",
|
|
||||||
"initializer_factor": 1.0,
|
|
||||||
"is_encoder_decoder": true,
|
|
||||||
"is_gated_act": true,
|
|
||||||
"layer_norm_epsilon": 1e-06,
|
|
||||||
"model_type": "t5",
|
|
||||||
"num_decoder_layers": 12,
|
|
||||||
"num_heads": 12,
|
|
||||||
"num_layers": 12,
|
|
||||||
"output_past": true,
|
|
||||||
"pad_token_id": 0,
|
|
||||||
"relative_attention_max_distance": 128,
|
|
||||||
"relative_attention_num_buckets": 32,
|
|
||||||
"tie_word_embeddings": false,
|
|
||||||
"torch_dtype": "float32",
|
|
||||||
"transformers_version": "4.26.0.dev0",
|
|
||||||
"use_cache": true,
|
|
||||||
"vocab_size": 32128
|
|
||||||
}
|
|
|
@ -1,11 +0,0 @@
|
||||||
{
|
|
||||||
"framework": "pytorch",
|
|
||||||
"task": "text2text-generation",
|
|
||||||
"model": {
|
|
||||||
"type": "T5",
|
|
||||||
"language": "zh"
|
|
||||||
},
|
|
||||||
"pipeline": {
|
|
||||||
"type": "text2text-generation"
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,175 +0,0 @@
|
||||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
||||||
# Copyright 2020, The T5 Authors and HuggingFace Inc.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
""" T5 model configuration"""
|
|
||||||
from typing import Mapping
|
|
||||||
|
|
||||||
from transformers.configuration_utils import PretrainedConfig
|
|
||||||
from transformers.onnx import OnnxSeq2SeqConfigWithPast
|
|
||||||
|
|
||||||
from modelscope.utils.logger import get_logger
|
|
||||||
|
|
||||||
logger = get_logger()
|
|
||||||
|
|
||||||
|
|
||||||
class T5Config(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`T5Model`] or a [`TFT5Model`]. It is used to
    instantiate a T5 model according to the specified arguments, defining the model architecture. Instantiating a
    configuration with the defaults will yield a similar configuration to that of the T5
    [t5-small](https://huggingface.co/t5-small) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Arguments:
        vocab_size (`int`, *optional*, defaults to 32128):
            Vocabulary size of the T5 model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`T5Model`] or [`TFT5Model`].
        d_model (`int`, *optional*, defaults to 512):
            Size of the encoder layers and the pooler layer.
        d_kv (`int`, *optional*, defaults to 64):
            Size of the key, query, value projections per attention head. The constructor stores the value as
            given; it does not check it against `d_model` or `num_heads`.
        d_ff (`int`, *optional*, defaults to 2048):
            Size of the intermediate feed forward layer in each `T5Block`.
        num_layers (`int`, *optional*, defaults to 6):
            Number of hidden layers in the Transformer encoder.
        num_decoder_layers (`int`, *optional*):
            Number of hidden layers in the Transformer decoder. Will use the same value as `num_layers` if not set.
        num_heads (`int`, *optional*, defaults to 8):
            Number of attention heads for each attention layer in the Transformer encoder.
        relative_attention_num_buckets (`int`, *optional*, defaults to 32):
            The number of buckets to use for each attention layer.
        relative_attention_max_distance (`int`, *optional*, defaults to 128):
            The maximum distance of the longer sequences for the bucket separation.
        dropout_rate (`float`, *optional*, defaults to 0.1):
            The ratio for all dropout layers.
        layer_norm_epsilon (`float`, *optional*, defaults to 1e-6):
            The epsilon used by the layer normalization layers.
        initializer_factor (`float`, *optional*, defaults to 1):
            A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
            testing).
        feed_forward_proj (`string`, *optional*, defaults to `"relu"`):
            Type of feed forward layer to be used, written either as `{ACT_FN}` or as `gated-{ACT_FN}`, e.g.
            `"relu"` or `"gated-gelu"`. T5v1.1 uses the `"gated-gelu"` feed forward projection. Original T5 uses
            `"relu"`. The derived attributes `dense_act_fn` and `is_gated_act` are computed from this value.
        is_encoder_decoder (`bool`, *optional*, defaults to `True`):
            Forwarded unchanged to `PretrainedConfig.__init__`.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models).
        pad_token_id (`int`, *optional*, defaults to 0):
            Forwarded unchanged to `PretrainedConfig.__init__`.
        eos_token_id (`int`, *optional*, defaults to 1):
            Forwarded unchanged to `PretrainedConfig.__init__`.
        kwargs:
            Any further keyword arguments are forwarded to `PretrainedConfig.__init__`.

    Raises:
        ValueError: If `feed_forward_proj` has more than two `-`-separated parts, or has two parts and the first
            one is not `gated`.
    """
    model_type = 't5'
    keys_to_ignore_at_inference = ['past_key_values']
    # Generic attribute names mapped onto the T5-specific names stored here.
    attribute_map = {
        'hidden_size': 'd_model',
        'num_attention_heads': 'num_heads',
        'num_hidden_layers': 'num_layers'
    }

    def __init__(self,
                 vocab_size=32128,
                 d_model=512,
                 d_kv=64,
                 d_ff=2048,
                 num_layers=6,
                 num_decoder_layers=None,
                 num_heads=8,
                 relative_attention_num_buckets=32,
                 relative_attention_max_distance=128,
                 dropout_rate=0.1,
                 layer_norm_epsilon=1e-6,
                 initializer_factor=1.0,
                 feed_forward_proj='relu',
                 is_encoder_decoder=True,
                 use_cache=True,
                 pad_token_id=0,
                 eos_token_id=1,
                 **kwargs):
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.d_kv = d_kv
        self.d_ff = d_ff
        self.num_layers = num_layers
        self.num_decoder_layers = (num_decoder_layers if num_decoder_layers
                                   is not None else self.num_layers
                                   )  # default = symmetry
        self.num_heads = num_heads
        self.relative_attention_num_buckets = relative_attention_num_buckets
        self.relative_attention_max_distance = relative_attention_max_distance
        self.dropout_rate = dropout_rate
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_factor = initializer_factor
        self.feed_forward_proj = feed_forward_proj
        self.use_cache = use_cache

        # `feed_forward_proj` is either '{ACT_FN}' or 'gated-{ACT_FN}'.
        act_info = self.feed_forward_proj.split('-')
        self.dense_act_fn = act_info[-1]
        self.is_gated_act = act_info[0] == 'gated'

        # Two parts are only valid with the 'gated' prefix; three or more
        # parts are never valid.
        has_unknown_prefix = len(act_info) > 1 and act_info[0] != 'gated'
        if has_unknown_prefix or len(act_info) > 2:
            raise ValueError(
                f'`feed_forward_proj`: {feed_forward_proj} is not a valid activation function of the dense layer. '
                'Please make sure `feed_forward_proj` is of the format `gated-{ACT_FN}` or `{ACT_FN}`, e.g. '
                "'gated-gelu' or 'relu'")

        # for backwards compatibility
        if feed_forward_proj == 'gated-gelu':
            self.dense_act_fn = 'gelu_new'

        super().__init__(
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            is_encoder_decoder=is_encoder_decoder,
            **kwargs,
        )
|
|
||||||
|
|
||||||
|
|
||||||
class T5OnnxConfig(OnnxSeq2SeqConfigWithPast):
    """ONNX export description of the T5 inputs and the default opset."""

    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        """Map every model input name to its `{axis index: axis name}` dict.

        The encoder inputs are always present. The decoder entries, and the
        name given to axis 1 of `attention_mask`, depend on `self.use_past`;
        when it is set, the mapping is also handed to
        `fill_with_past_key_values_` before being returned.
        """
        axes = {
            name: {
                0: 'batch',
                1: 'encoder_sequence'
            }
            for name in ('input_ids', 'attention_mask')
        }

        if not self.use_past:
            for name in ('decoder_input_ids', 'decoder_attention_mask'):
                axes[name] = {0: 'batch', 1: 'decoder_sequence'}
            return axes

        axes['attention_mask'][1] = 'past_encoder_sequence + sequence'
        axes['decoder_input_ids'] = {0: 'batch'}
        axes['decoder_attention_mask'] = {
            0: 'batch',
            1: 'past_decoder_sequence + sequence'
        }
        self.fill_with_past_key_values_(axes, direction='inputs')
        return axes

    @property
    def default_onnx_opset(self) -> int:
        """Return the ONNX opset version used by default, which is 13."""
        return 13
|
|
|
@ -1,89 +0,0 @@
|
||||||
import torch
|
|
||||||
from modelscope.pipelines import pipeline
|
|
||||||
from modelscope.utils.constant import Tasks
|
|
||||||
|
|
||||||
from modelscope.preprocessors import TextGenerationTransformersPreprocessor
|
|
||||||
|
|
||||||
from modeling_t5 import T5ForConditionalGeneration
|
|
||||||
from modelscope.utils.config import Config
|
|
||||||
from configuration import T5Config
|
|
||||||
from modelscope import snapshot_download
|
|
||||||
from transformers import AutoConfig
|
|
||||||
|
|
||||||
|
|
||||||
# Manual smoke test: download PromptCLUE, load it with the local T5
# implementation, and run text2text generation through a ModelScope pipeline.

# Seed the torch RNGs (CPU and all CUDA devices). The first two examples
# sample (`do_sample=True`), so their recorded outputs depend on RNG state and
# may not reproduce exactly.
seed = 4321
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

model_dir = snapshot_download("ClueAI/PromptCLUE-base-v1-5")

# Load the configuration on its own first. `config` and `kwargs` are not used
# below; this call only exercises the config loading path for `model_dir`.
config, kwargs = AutoConfig.from_pretrained(
    model_dir,
    return_unused_kwargs=True,
    trust_remote_code=True,
    code_revision=None,
    _commit_hash=None,
)

# Load the weights once. The previous version first built a randomly
# initialised `T5ForConditionalGeneration(config)` and immediately replaced it
# with the result of calling `from_pretrained` on that instance.
model = T5ForConditionalGeneration.from_pretrained(model_dir)

preprocessor = TextGenerationTransformersPreprocessor(model_dir)

# Exercise the tokenizer path directly; the result is not used further.
out = preprocessor._tokenize_text("生成与下列文字相同意")

tokenizer = preprocessor.nlp_tokenizer
response, history = model.chat(tokenizer, "生成与下列文字相同意", history=[])

pipeline_t2t = pipeline(
    task=Tasks.text2text_generation, model=model, preprocessor=preprocessor
)

# Each entry is (prompt, extra generation kwargs, output text recorded by the
# original author as `{'text': ...}`). Only the first entry ran in the
# original script; the rest were kept as commented-out calls.
EXAMPLES = (
    (
        "生成与下列文字相同意思的句子:\n白云遍地无人扫\n答案:",
        {"do_sample": True, "top_p": 0.8},
        '白云散去无踪,没人扫。',
    ),
    (
        '改写下面的文字,确保意思相同:\n一个如此藐视本国人民民主权利的人,怎么可能捍卫外国人的民权?\n答案:',
        {"do_sample": True, "top_p": 0.8},
        '对一个如此藐视本国人民民主权利的人,怎么能捍卫外国人的民权?',
    ),
    (
        '根据问题给出答案:\n问题:手指发麻的主要可能病因是:\n答案',
        {},
        '神经损伤,颈椎病,贫血,高血压',
    ),
    (
        '问答:\n问题:黄果悬钩子的目是:\n答案:',
        {},
        '蔷薇目',
    ),
    (
        '情感分析:\n这个看上去还可以,但其实我不喜欢\n选项:积极,消极',
        {},
        '消极',
    ),
    (
        "下面句子是否表示了相同的语义:\n文本1:糖尿病腿麻木怎么办?\n文本2:糖尿病怎样控制生活方式\n选项:相似,不相似\n答案:",
        {},
        '不相似',
    ),
    (
        '这是关于哪方面的新闻:\n如果日本沉没,中国会接收日本难民吗?\n选项:故事,文化,娱乐,体育,财经,房产,汽车,教育,科技,军事,旅游,国际,股票,农业,游戏',
        {},
        '国际',
    ),
    (
        "阅读文本抽取关键信息:\n张玄武1990年出生中国国籍无境外居留权博士学历现任杭州线锁科技技术总监。\n问题:机构,人名,职位,籍贯,专业,国籍,学历,种族\n答案:",
        {},
        '机构:杭州线锁科技技术_人名:张玄武_职位:博士学历',
    ),
    (
        "翻译成英文:\n杀不死我的只会让我更强大\n答案:",
        {},
        'To kill my life only let me stronger',
    ),
    (
        '为下面的文章生成摘要:\n北京时间9月5日12时52分,四川甘孜藏族自治州泸定县发生6.8级地震。地震发生后,领导高度重视并作出重要指示,要求把抢救生命作为首要任务,全力救援受灾群众,最大限度减少人员伤亡',
        {},
        '四川甘孜发生6.8级地震',
    ),
    (
        "推理关系判断:\n前提:小明今天在北京\n假设:小明在深圳旅游\n选项:矛盾,蕴含,中立\n答案:",
        {},
        '蕴涵',
    ),
    (
        '阅读以下对话并回答问题。\n男:今天怎么这么晚才来上班啊?女:昨天工作到很晚,而且我还感冒了。男:那你回去休息吧,我帮你请假。女:谢谢你。\n问题:女的怎么样?\n选项:正在工作,感冒了,在打电话,要出差。',
        {},
        '感冒了',
    ),
    (
        "文本纠错:\n告诉二营长,叫他彻回来,我李云龙从不打没有准备的杖\n答案:",
        {},
        '告诉二营长,叫他下来,我李云龙从不打没有准备的仗',
    ),
    (
        "问答:\n问题:小米的创始人是谁?\n答案:",
        {},
        '小米创始人:雷军',
    ),
)

# Flip to True to run every example instead of only the first one.
RUN_ALL_EXAMPLES = False

for prompt, generate_kwargs, _recorded_text in (
    EXAMPLES if RUN_ALL_EXAMPLES else EXAMPLES[:1]
):
    print(pipeline_t2t(prompt, **generate_kwargs))
|
|
|
@ -1,406 +0,0 @@
|
||||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
||||||
# Copyright 2018 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
import copy
|
|
||||||
import warnings
|
|
||||||
from typing import Optional, Tuple, Union, List, Dict, Any
|
|
||||||
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from torch import nn
|
|
||||||
from torch.nn import CrossEntropyLoss
|
|
||||||
from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
|
|
||||||
|
|
||||||
from modelscope.metainfo import Models
|
|
||||||
from modelscope.models.builder import MODELS
|
|
||||||
from modelscope.outputs import (
|
|
||||||
AttentionBackboneModelOutput,
|
|
||||||
Seq2SeqLMOutput,
|
|
||||||
TokenGeneratorOutput,
|
|
||||||
)
|
|
||||||
from modelscope.utils.constant import Tasks
|
|
||||||
from modelscope.utils.logger import get_logger
|
|
||||||
from backbone import T5PreTrainedModel, T5Stack
|
|
||||||
from configuration import T5Config
|
|
||||||
|
|
||||||
logger = get_logger()
|
|
||||||
|
|
||||||
# Warning message for FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
|
|
||||||
__HEAD_MASK_WARNING_MSG = """
|
|
||||||
The input argument `head_mask` was split into two arguments `head_mask` and
|
|
||||||
`decoder_head_mask`. Currently, `decoder_head_mask` is set to copy `head_mask`,
|
|
||||||
but this feature is deprecated and will be removed in future versions. If you do
|
|
||||||
not want to use any `decoder_head_mask` now, please set `decoder_head_mask =
|
|
||||||
torch.ones(num_layers, num_heads)`.
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
class T5ForConditionalGeneration(T5PreTrainedModel):
    """T5 encoder/decoder pair with a linear language-modelling head.

    The encoder and the decoder are two `T5Stack` instances built from deep
    copies of the same configuration; both receive the same `nn.Embedding`
    (`self.shared`). `lm_head` projects decoder output from `d_model` to
    `vocab_size` without a bias.
    """

    _keys_to_ignore_on_load_missing = [
        r"encoder\.embed_tokens\.weight",
        r"decoder\.embed_tokens\.weight",
        r"lm_head\.weight",
    ]
    _keys_to_ignore_on_load_unexpected = [
        r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight",
    ]

    def __init__(self, config: T5Config, device_map=None, **kwargs):
        """Build the shared embedding, both stacks and the LM head.

        Args:
            config: Model configuration; deep-copied for each stack so the
                per-stack flags set below do not leak into `config`.
            device_map: When equal to the string ``"auto"``, `parallelize()`
                is called at the end of construction. Any other value is
                ignored here.
            **kwargs: Accepted and ignored.
        """
        super().__init__(config)
        self.model_dim = config.d_model

        self.shared = nn.Embedding(config.vocab_size, config.d_model)

        encoder_config = copy.deepcopy(config)
        encoder_config.is_decoder = False
        encoder_config.use_cache = False
        encoder_config.is_encoder_decoder = False
        self.encoder = T5Stack(encoder_config, self.shared)

        decoder_config = copy.deepcopy(config)
        decoder_config.is_decoder = True
        decoder_config.is_encoder_decoder = False
        decoder_config.num_layers = config.num_decoder_layers
        self.decoder = T5Stack(decoder_config, self.shared)

        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

        # Model parallel
        self.model_parallel = False
        if device_map == "auto":
            self.parallelize()

    def parallelize(self, device_map=None):
        """Spread the encoder and decoder over devices and enable model parallelism.

        Args:
            device_map: Mapping handed to both stacks' `parallelize`. When
                ``None`` one is built with `get_device_map` from the number of
                encoder blocks and ``range(torch.cuda.device_count())``.
        """
        self.device_map = (
            get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
            if device_map is None
            else device_map
        )
        assert_device_map(self.device_map, len(self.encoder.block))
        self.encoder.parallelize(self.device_map)
        self.decoder.parallelize(self.device_map)
        self.lm_head = self.lm_head.to(self.decoder.first_device)
        self.model_parallel = True

    def deparallelize(self):
        """Undo `parallelize`: move everything to CPU and clear the CUDA cache."""
        self.encoder.deparallelize()
        self.decoder.deparallelize()
        self.encoder = self.encoder.to("cpu")
        self.decoder = self.decoder.to("cpu")
        self.lm_head = self.lm_head.to("cpu")
        self.model_parallel = False
        self.device_map = None
        torch.cuda.empty_cache()

    def get_input_embeddings(self):
        """Return the embedding module shared by encoder and decoder."""
        return self.shared

    def set_input_embeddings(self, new_embeddings):
        """Replace the shared embedding here and in both stacks."""
        self.shared = new_embeddings
        self.encoder.set_input_embeddings(new_embeddings)
        self.decoder.set_input_embeddings(new_embeddings)

    def set_output_embeddings(self, new_embeddings):
        """Replace the LM head."""
        self.lm_head = new_embeddings

    def get_output_embeddings(self):
        """Return the LM head."""
        return self.lm_head

    def get_encoder(self):
        """Return the encoder stack."""
        return self.encoder

    def get_decoder(self):
        """Return the decoder stack."""
        return self.decoder

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        decoder_head_mask: Optional[torch.FloatTensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
        """Run encoder (unless `encoder_outputs` is given), decoder and LM head.

        Args:
            input_ids, attention_mask, inputs_embeds, head_mask: Passed to the
                encoder when `encoder_outputs` is ``None``. `attention_mask`
                is also passed to the decoder as `encoder_attention_mask`.
            decoder_input_ids, decoder_attention_mask, decoder_inputs_embeds,
            decoder_head_mask, cross_attn_head_mask, past_key_values: Passed
                to the decoder. When `head_mask` is given without
                `decoder_head_mask`, the latter is set to `head_mask` (with a
                `FutureWarning` if encoder and decoder have the same number
                of layers).
            encoder_outputs: Pre-computed encoder output; its first element is
                used as the encoder hidden states.
            labels: When given, a cross-entropy loss (``ignore_index=-100``)
                is computed against the logits. If neither `decoder_input_ids`
                nor `decoder_inputs_embeds` is given, the decoder input ids
                are obtained from `self._shift_right(labels)`.
            use_cache: Defaults to `self.config.use_cache`.
            output_attentions, output_hidden_states: Forwarded to both stacks.
            return_dict: Defaults to `self.config.use_return_dict`.
            **kwargs: Accepted and ignored.

        Returns:
            A `Seq2SeqLMOutput` when `return_dict` is true. Otherwise the
            tuple ``(lm_logits,) + decoder_outputs[1:] + encoder_outputs``,
            prefixed with the loss when `labels` was given.
        """
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
        if head_mask is not None and decoder_head_mask is None:
            if self.config.num_layers == self.config.num_decoder_layers:
                # A bare `__HEAD_MASK_WARNING_MSG` inside a class body is
                # name-mangled to `_T5ForConditionalGeneration__HEAD_...` and
                # raises NameError; a string key is not mangled.
                warnings.warn(globals()["__HEAD_MASK_WARNING_MSG"], FutureWarning)
                decoder_head_mask = head_mask

        # Encode if needed (training, first prediction pass)
        if encoder_outputs is None:
            # Convert encoder inputs in embeddings if needed
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        elif return_dict and not isinstance(
            encoder_outputs, AttentionBackboneModelOutput
        ):
            # Wrap a plain tuple so the attribute access at the end works.
            encoder_outputs = AttentionBackboneModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        hidden_states = encoder_outputs[0]

        if self.model_parallel:
            torch.cuda.set_device(self.decoder.first_device)

        if (
            labels is not None
            and decoder_input_ids is None
            and decoder_inputs_embeds is None
        ):
            # get decoder inputs from shifting lm labels to the right
            decoder_input_ids = self._shift_right(labels)

        # Set device for model parallelism
        if self.model_parallel:
            torch.cuda.set_device(self.decoder.first_device)
            hidden_states = hidden_states.to(self.decoder.first_device)
            if decoder_input_ids is not None:
                decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
            if attention_mask is not None:
                attention_mask = attention_mask.to(self.decoder.first_device)
            if decoder_attention_mask is not None:
                decoder_attention_mask = decoder_attention_mask.to(
                    self.decoder.first_device
                )

        # Decode
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            past_key_values=past_key_values,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=attention_mask,
            head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = decoder_outputs[0]

        # Set device for model parallelism
        if self.model_parallel:
            torch.cuda.set_device(self.encoder.first_device)
            self.lm_head = self.lm_head.to(self.encoder.first_device)
            sequence_output = sequence_output.to(self.lm_head.weight.device)

        if self.config.tie_word_embeddings:
            # Rescale output before projecting on vocab See
            # https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
            sequence_output = sequence_output * (self.model_dim ** -0.5)

        lm_logits = self.lm_head(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss(ignore_index=-100)
            # The logits may have been moved above; `.to` is a no-op when the
            # labels already live on the same device.
            labels = labels.to(lm_logits.device)
            loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
            # TODO(thom): Add z_loss
            # https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666

        if not return_dict:
            output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
            return ((loss,) + output) if loss is not None else output

        return Seq2SeqLMOutput(
            loss=loss,
            logits=lm_logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past=None,
        attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        cross_attn_head_mask=None,
        use_cache=None,
        encoder_outputs=None,
        **kwargs,
    ):
        """Assemble the keyword arguments for one `forward` call during generation.

        `input_ids` is returned under the `decoder_input_ids` key; when `past`
        is given only its last column is kept. `**kwargs` is ignored.
        """
        # cut decoder_input_ids if past is used
        if past is not None:
            input_ids = input_ids[:, -1:]

        return {
            "decoder_input_ids": input_ids,
            "past_key_values": past,
            "encoder_outputs": encoder_outputs,
            "attention_mask": attention_mask,
            "head_mask": head_mask,
            "decoder_head_mask": decoder_head_mask,
            "cross_attn_head_mask": cross_attn_head_mask,
            "use_cache": use_cache,
        }

    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
        """Return `self._shift_right(labels)`."""
        return self._shift_right(labels)

    def generate(
        self,
        *args,
        **kwargs,
    ):
        """Call the parent `generate` and wrap the result in `TokenGeneratorOutput`.

        When the parent returns a tensor it is used as `sequences` directly;
        otherwise the first element of the returned object is used.
        """
        output = super().generate(*args, **kwargs)
        return TokenGeneratorOutput(
            sequences=output if isinstance(output, torch.Tensor) else output[0]
        )

    def chat(
        self,
        tokenizer,
        query: str,
        history: Optional[List[Dict[str, str]]] = None,
        role: str = "user",
    ):
        """Tokenize `query`, run `sample`, and decode the first returned sequence.

        Args:
            tokenizer: Callable returning a mapping with an ``"input_ids"``
                entry; must also provide `decode`.
            query: Text to send to the model.
            history: List that receives ``{"role": role, "content": query}``;
                a new list is created when ``None``. The caller's list is
                mutated in place.
            role: Value stored under ``"role"`` in the history entry.

        Returns:
            ``(response, history)`` where `response` is the decoded text of
            the first sequence returned by `sample`.
        """
        if history is None:
            history = []
        token = tokenizer(query)
        inputs = torch.as_tensor([token["input_ids"]])
        inputs_tensor = inputs.to(next(self.parameters()).device)

        generation_config = copy.deepcopy(self.generation_config)
        input_ids = inputs_tensor.repeat_interleave(
            generation_config.num_return_sequences, dim=0
        )

        outputs = self.sample(
            input_ids,
            generation_config.pad_token_id,
            generation_config.eos_token_id,
            generation_config.output_hidden_states,
            tokenizer,
        )

        outputs = outputs.tolist()[0][:]
        response = tokenizer.decode(outputs)
        history.append({"role": role, "content": query})
        return response, history

    def sample(
        self,
        input_ids: torch.LongTensor,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[Union[int, List[int]]] = None,
        output_hidden_states: Optional[bool] = None,
        tokenizer=None,
    ):
        """Append one token per step to `input_ids` until every row hit an EOS id.

        Args:
            input_ids: Starting ids; one row per sequence.
            pad_token_id: Id emitted for rows that already finished.
            eos_token_id: A single id or a list of ids that end a row.
            output_hidden_states: Unused.
            tokenizer: Unused.

        Returns:
            `input_ids` extended by the generated tokens. The step on which
            the last row finishes is not appended.

        NOTE(review): the loop unpacks `self(input_ids)` as
        ``(probs, next_tokens)``, whereas `forward` in this class returns a
        `Seq2SeqLMOutput` or a tuple led by the logits, and is called here
        without decoder inputs - confirm this method can run for this model.
        There is also no length limit, so the loop only ends on EOS.
        """
        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]
        eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device)

        # 1 for rows that already produced an EOS id, 0 otherwise.
        is_finished = torch.zeros(
            input_ids.shape[0], dtype=torch.long, device=input_ids.device
        )
        while True:
            _probs, next_tokens = self(input_ids)

            # finished sentences should add a padding token to next
            pad_token = pad_token_id * is_finished
            next_tokens = next_tokens * (1 - is_finished) + pad_token

            # Compare every row against every EOS id. A plain
            # `next_tokens.eq(eos_token_id_tensor)` only broadcasts correctly
            # for a single EOS id.
            hit_eos = next_tokens[:, None].eq(eos_token_id_tensor).any(dim=-1)
            is_finished = is_finished | hit_eos
            if is_finished.min() == 1:  # all batch is finish
                break

            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        return input_ids

    def _reorder_cache(self, past, beam_idx):
        """Reorder every cached state along dim 0 according to `beam_idx`.

        Returns `past` unchanged (after logging a warning) when it is
        ``None``; otherwise a tuple of per-layer tuples of reordered states.
        """
        # if decoder past is not included in output
        # speedy decoding is disabled and no need to reorder
        if past is None:
            logger.warning(
                "You might want to consider setting `use_cache=True` to speed up decoding"
            )
            return past

        reordered_decoder_past = ()
        for layer_past_states in past:
            # Pick the rows named by `beam_idx` (dim 0) from each cached
            # state of this layer, on that state's own device.
            reordered_layer_past_states = tuple(
                layer_past_state.index_select(
                    0, beam_idx.to(layer_past_state.device)
                )
                for layer_past_state in layer_past_states
            )

            assert reordered_layer_past_states[0].shape == layer_past_states[0].shape
            assert len(reordered_layer_past_states) == len(layer_past_states)

            reordered_decoder_past = reordered_decoder_past + (
                reordered_layer_past_states,
            )
        return reordered_decoder_past
|
|
|
@ -1,33 +0,0 @@
|
||||||
{
|
|
||||||
"_name_or_path": "./bloomz_560m_pretrained",
|
|
||||||
"apply_residual_connection_post_layernorm": false,
|
|
||||||
"architectures": [
|
|
||||||
"BloomForCausalLM"
|
|
||||||
],
|
|
||||||
"attention_dropout": 0.0,
|
|
||||||
"attention_softmax_in_fp32": true,
|
|
||||||
"bias_dropout_fusion": true,
|
|
||||||
"bos_token_id": 1,
|
|
||||||
"eos_token_id": 2,
|
|
||||||
"hidden_dropout": 0.0,
|
|
||||||
"hidden_size": 1024,
|
|
||||||
"initializer_range": 0.02,
|
|
||||||
"layer_norm_epsilon": 1e-05,
|
|
||||||
"masked_softmax_fusion": true,
|
|
||||||
"model_type": "bloom",
|
|
||||||
"n_head": 16,
|
|
||||||
"n_inner": null,
|
|
||||||
"n_layer": 24,
|
|
||||||
"offset_alibi": 100,
|
|
||||||
"pad_token_id": 3,
|
|
||||||
"pretraining_tp": 1,
|
|
||||||
"seq_length": 2048,
|
|
||||||
"skip_bias_add": true,
|
|
||||||
"skip_bias_add_qkv": false,
|
|
||||||
"slow_but_exact": false,
|
|
||||||
"torch_dtype": "float16",
|
|
||||||
"transformers_version": "4.28.1",
|
|
||||||
"unk_token_id": 0,
|
|
||||||
"use_cache": true,
|
|
||||||
"vocab_size": 250880
|
|
||||||
}
|
|
|
@ -1,11 +0,0 @@
|
||||||
{
|
|
||||||
"framework": "pytorch",
|
|
||||||
"task": "text-generation",
|
|
||||||
"model": {
|
|
||||||
"type": "bloom"
|
|
||||||
},
|
|
||||||
"pipeline": {
|
|
||||||
"type": "seqgpt"
|
|
||||||
},
|
|
||||||
"allow_remote": true
|
|
||||||
}
|
|
|
@ -1,242 +0,0 @@
|
||||||
# coding=utf-8
|
|
||||||
# Copyright 2022 the Big Science Workshop and HuggingFace Inc. team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
""" Bloom configuration"""
|
|
||||||
from collections import OrderedDict
|
|
||||||
from typing import TYPE_CHECKING, Any, List, Mapping, Optional
|
|
||||||
|
|
||||||
from packaging import version
|
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from ... import PreTrainedTokenizer, TensorType
|
|
||||||
|
|
||||||
from transformers.configuration_utils import PretrainedConfig
|
|
||||||
from transformers.onnx import OnnxConfigWithPast, PatchingSpec
|
|
||||||
from transformers.utils import is_torch_available, logging
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
|
||||||
|
|
||||||
# Lookup table: Hub model id -> URL of that checkpoint's config.json.
# NOTE(review): only the "bigscience/bloom" entry points at a `resolve/` URL; the
# smaller checkpoints point at `blob/` URLs. The values are kept exactly as they were.
BLOOM_PRETRAINED_CONFIG_ARCHIVE_MAP = dict(
    [("bigscience/bloom", "https://huggingface.co/bigscience/bloom/resolve/main/config.json")]
    + [
        (
            f"bigscience/bloom-{size}",
            f"https://huggingface.co/bigscience/bloom-{size}/blob/main/config.json",
        )
        for size in ("560m", "1b1", "1b7", "3b", "7b1")
    ]
)
|
|
||||||
|
|
||||||
|
|
||||||
class BloomConfig(PretrainedConfig):
    """
    This is the configuration class to store the configuration of a [`BloomModel`]. It is used to instantiate a Bloom
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to the Bloom architecture
    [bigscience/bloom](https://huggingface.co/bigscience/bloom).

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        vocab_size (`int`, *optional*, defaults to 250880):
            Vocabulary size of the Bloom model. Defines the maximum number of different tokens that can be represented
            by the `input_ids` passed when calling [`BloomModel`]. Check [this
            discussion](https://huggingface.co/bigscience/bloom/discussions/120#633d28389addb8530b406c2a) on how the
            `vocab_size` has been defined.
        hidden_size (`int`, *optional*, defaults to 64):
            Dimensionality of the embeddings and hidden states. For backward compatibility the keyword `n_embed`,
            when given, takes precedence over this value.
        n_layer (`int`, *optional*, defaults to 2):
            Number of hidden layers in the Transformer. Also reachable as `num_hidden_layers`.
        n_head (`int`, *optional*, defaults to 8):
            Number of attention heads for each attention layer. Also reachable as `num_attention_heads`.
        layer_norm_epsilon (`float`, *optional*, defaults to 1e-5):
            The epsilon to use in the layer normalization layers.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models).
        bos_token_id (`int`, *optional*, defaults to 1):
            Id of the beginning-of-sequence token. Stored on the config and forwarded to [`PretrainedConfig`].
        eos_token_id (`int`, *optional*, defaults to 2):
            Id of the end-of-sequence token. Stored on the config and forwarded to [`PretrainedConfig`].
        apply_residual_connection_post_layernorm (`bool`, *optional*, defaults to `False`):
            If enabled, use the layer norm of the hidden states as the residual in the transformer blocks
        hidden_dropout (`float`, *optional*, defaults to 0.0):
            Dropout rate of the dropout function on the bias dropout.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            Dropout rate applied to the attention probs
        pretraining_tp (`int`, *optional*, defaults to `1`):
            Experimental feature. Tensor parallelism rank used during pretraining with Megatron. Please refer to [this
            document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is
            necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
            issue](https://github.com/pytorch/pytorch/issues/76232). Note also that this is enabled only when
            `slow_but_exact=True`.
        slow_but_exact (`bool`, *optional*, defaults to `False`):
            Experimental feature. Whether to use slow but exact implementation of the attention mechanism. While
            merging the TP rank tensors, due to slicing operations the results may be slightly different between the
            model trained on Megatron and our model. Please refer to [this
            issue](https://github.com/pytorch/pytorch/issues/76232). A solution to obtain more accurate results is to
            enable this feature. Enabling this will hurt the computational time of the inference. Will be probably
            resolved in the future once the main model has been fine-tuned with TP_rank=1.

    Example:

    ```python
    >>> from transformers import BloomConfig, BloomModel

    >>> # Initializing a Bloom configuration
    >>> configuration = BloomConfig()

    >>> # Initializing a model (with random weights) from the configuration
    >>> model = BloomModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "bloom"
    keys_to_ignore_at_inference = ["past_key_values"]
    # Generic attribute names mapped onto the BLOOM-specific ones stored below.
    attribute_map = {
        "num_hidden_layers": "n_layer",
        "num_attention_heads": "n_head",
    }

    def __init__(
        self,
        vocab_size=250880,
        hidden_size=64,
        n_layer=2,
        n_head=8,
        layer_norm_epsilon=1e-5,
        initializer_range=0.02,
        use_cache=True,
        bos_token_id=1,
        eos_token_id=2,
        apply_residual_connection_post_layernorm=False,
        hidden_dropout=0.0,
        attention_dropout=0.0,
        pretraining_tp=1,  # TP rank used when training with megatron
        slow_but_exact=False,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        # Backward compatibility with n_embed kwarg: it is removed from kwargs so it is
        # not forwarded to the parent class, and it wins over `hidden_size` when present.
        n_embed = kwargs.pop("n_embed", None)
        self.hidden_size = hidden_size if n_embed is None else n_embed
        self.n_layer = n_layer
        self.n_head = n_head
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_range = initializer_range
        self.use_cache = use_cache
        self.pretraining_tp = pretraining_tp
        self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
        self.hidden_dropout = hidden_dropout
        self.attention_dropout = attention_dropout

        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.slow_but_exact = slow_but_exact

        super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
class BloomOnnxConfig(OnnxConfigWithPast):
    """ONNX export settings for BLOOM: input axes, dummy inputs, opset and validation tolerance."""

    torch_onnx_minimum_version = version.parse("1.12")

    def __init__(
        self,
        config: PretrainedConfig,
        task: str = "default",
        patching_specs: List[PatchingSpec] = None,
        use_past: bool = False,
    ):
        super().__init__(config, task=task, patching_specs=patching_specs, use_past=use_past)
        # TODO: how to do that better?
        # A missing or falsy `pad_token_id` (None or 0) ends up as 0 on the wrapped config.
        if getattr(self._config, "pad_token_id", None):
            return
        self._config.pad_token_id = 0

    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        """Dynamic-axis names for every exported input, in declaration order."""
        axes = OrderedDict({"input_ids": {0: "batch", 1: "sequence"}})
        if self.use_past:
            # BLOOM stores values on dynamic axis 2. For more details see: https://github.com/huggingface/transformers/pull/18344
            self.fill_with_past_key_values_(axes, direction="inputs", inverted_values_shape=True)
            mask_sequence_axis = "past_sequence + sequence"
        else:
            mask_sequence_axis = "sequence"
        axes["attention_mask"] = {0: "batch", 1: mask_sequence_axis}
        return axes

    @property
    def num_layers(self) -> int:
        """Number of layers, read from `n_layer` of the wrapped config."""
        return self._config.n_layer

    @property
    def num_attention_heads(self) -> int:
        """Number of attention heads, read from `n_head` of the wrapped config."""
        return self._config.n_head

    @property
    def atol_for_validation(self) -> float:
        """Absolute tolerance reported for validation of the exported model."""
        return 1e-3

    def generate_dummy_inputs(
        self,
        tokenizer: "PreTrainedTokenizer",
        batch_size: int = -1,
        seq_length: int = -1,
        is_pair: bool = False,
        framework: Optional["TensorType"] = None,
    ) -> Mapping[str, Any]:
        """Build dummy inputs ordered as `input_ids`, optional `past_key_values`, `attention_mask`.

        Raises:
            ValueError: If `use_past` is set but PyTorch is not available.
        """
        # The lookup deliberately starts after OnnxConfigWithPast in the MRO, so that
        # class's own generate_dummy_inputs is skipped.
        base_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
            tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
        )

        # We need to order the inputs in the way they appear in the forward()
        dummy_inputs = OrderedDict({"input_ids": base_inputs["input_ids"]})

        if not self.use_past:
            dummy_inputs["attention_mask"] = base_inputs["attention_mask"]
            return dummy_inputs

        # Need to add the past_keys
        if not is_torch_available():
            raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
        import torch

        batch, seqlen = base_inputs["input_ids"].shape
        # Not using the same length for past_key_values
        past_len = seqlen + 2
        heads = self.num_attention_heads
        head_dim = self._config.hidden_size // heads
        # Keys and values use transposed layouts of the last two axes.
        key_shape = (batch * heads, head_dim, past_len)
        value_shape = (batch * heads, past_len, head_dim)
        dummy_inputs["past_key_values"] = [
            (torch.zeros(key_shape), torch.zeros(value_shape)) for _ in range(self.num_layers)
        ]

        # Extend the mask with ones covering the dummy past positions.
        attention_mask = base_inputs["attention_mask"]
        dummy_inputs["attention_mask"] = torch.cat(
            [attention_mask, torch.ones(batch, past_len, dtype=attention_mask.dtype)], dim=1
        )
        return dummy_inputs

    @property
    def default_onnx_opset(self) -> int:
        """ONNX opset version used by default for the export."""
        return 13
|
|
|
@ -1,255 +0,0 @@
|
||||||
# coding=utf-8
|
|
||||||
# Copyright 2022 The HuggingFace Inc. team.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
"""Convert BigScience BLOOM checkpoint."""
|
|
||||||
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
import re
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from transformers import BloomConfig, BloomModel
|
|
||||||
from transformers.file_utils import CONFIG_NAME, WEIGHTS_NAME
|
|
||||||
from transformers.utils import logging
|
|
||||||
|
|
||||||
|
|
||||||
logging.set_verbosity_info()
|
|
||||||
|
|
||||||
# Suffixes of parameter names that the conversion averages across TP ranks
# (summed, then divided by the number of ranks) instead of concatenating.
WEIGHTS_TO_AVERAGE_ENDSWITH = (
    # layer-norm parameters
    ["word_embeddings_layernorm.weight", "word_embeddings_layernorm.bias"]
    + ["input_layernorm.weight", "input_layernorm.bias"]
    + ["post_attention_layernorm.weight", "post_attention_layernorm.bias"]
    # biases of the two output projections
    + ["self_attention.dense.bias", "mlp.dense_4h_to_h.bias"]
    # final layer norm
    + ["ln_f.weight", "ln_f.bias"]
)

# Substrings of parameter names that the conversion concatenates along dim 1
# (row-parallel layers); every other non-averaged weight is concatenated along dim 0.
WEIGHTS_WITH_ROW_PARALLELISM_CONTAIN = ["mlp.dense_4h_to_h.weight"] + ["self_attention.dense.weight"]
|
|
||||||
|
|
||||||
|
|
||||||
def layer_name_mapping(key, file):
    """Convert Megatron-DeepSpeed TP/PP weights mapping in transformers PP only.

    Args:
        key: Parameter name as stored in the checkpoint shard.
        file: Name of the shard file. For transformer-block parameters it must contain
            ``layer_`` immediately followed by the layer number.

    Returns:
        The fixed transformers name for embedding / final layer-norm parameters,
        otherwise ``h.N.`` + ``key`` where ``N`` is the file's layer number minus 3.

    Raises:
        ValueError: If ``key`` is a transformer-block parameter and no layer number
            can be read from ``file``.
    """
    # Handle first and last layers
    layer_rename_map = {
        "word_embeddings.weight": "word_embeddings.weight",
        "word_embeddings.norm.weight": "word_embeddings_layernorm.weight",
        "word_embeddings.norm.bias": "word_embeddings_layernorm.bias",
        "weight": "ln_f.weight",
        "bias": "ln_f.bias",
    }

    if key in layer_rename_map:
        return layer_rename_map[key]

    # Handle transformer blocks
    match = re.match(r".*layer_(\d*).*", file)
    if match is None or not match[1]:
        # Previously this surfaced as an opaque TypeError / int('') failure.
        raise ValueError(f"Cannot read a layer number from checkpoint file name {file!r} (key {key!r}).")
    # Block numbering starts three files in: file `layer_03` maps to block `h.0`.
    layer_number = int(match[1]) - 3
    return f"h.{layer_number}." + key
|
|
||||||
|
|
||||||
|
|
||||||
def get_dtype_size(dtype):
    """Return the size in bytes of one element of ``dtype``.

    ``torch.bool`` is counted as one bit (``1 / 8``). For every other dtype the bit
    width is read from the digits that end its string form (e.g. ``torch.float16``).

    Raises:
        ValueError: If the string form of ``dtype`` does not end in a bit width.
    """
    if dtype == torch.bool:
        return 1 / 8
    trailing_bits = re.search(r"[^\d](\d+)$", str(dtype))
    if trailing_bits is not None:
        return int(trailing_bits.group(1)) // 8
    raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
|
|
||||||
|
|
||||||
|
|
||||||
def _list_layer_files(bloom_checkpoint_path):
    """Return the sorted names of the TP-rank-0 layer files in the checkpoint folder."""
    return sorted(
        name for name in os.listdir(bloom_checkpoint_path) if name.startswith("layer") and "model_00" in name
    )


def _is_averaged_weight(key):
    """Tell whether ``key`` names a weight that is averaged (not concatenated) across TP ranks."""
    return key.endswith(tuple(WEIGHTS_TO_AVERAGE_ENDSWITH))


def _merge_tp_shards(bloom_checkpoint_path, file, pretraining_tp):
    """Load ``file`` for every TP rank and merge the ranks into one renamed state dict."""
    tensors = None
    for tp_rank in range(pretraining_tp):
        # load all TP files
        # NOTE(review): assumes rank files are numbered with two zero-padded digits
        # (`model_00`, `model_01`, ...); identical to the old `model_0{i}` for ranks below 10.
        f_name = file.replace("model_00", f"model_{tp_rank:02d}")
        temp = torch.load(os.path.join(bloom_checkpoint_path, f_name), map_location="cpu")

        # Rename keys in the transformers names
        for key in list(temp.keys()):
            temp[layer_name_mapping(key, file)] = temp.pop(key)

        if tensors is None:
            tensors = temp
            continue

        for key in tensors.keys():
            if _is_averaged_weight(key):
                # We average (sum and then divide) some weights across TP ranks (see https://github.com/bigscience-workshop/Megatron-DeepSpeed/blob/olruwase/sync_layer_norms/megatron/training.py#L425)
                tensors[key] += temp[key]
            else:
                # Some weights are RowParallelLinear in Megatron-Deepspeed, others are ColumnParallel
                cat_dim = 1 if any(text in key for text in WEIGHTS_WITH_ROW_PARALLELISM_CONTAIN) else 0
                # We concatenate these weights across TP ranks
                tensors[key] = torch.cat([tensors[key], temp[key]], dim=cat_dim)

    # Divide by the number of TP the weights we want to average
    for key in tensors.keys():
        if _is_averaged_weight(key):
            tensors[key] = tensors[key] / pretraining_tp
    return tensors


def convert_bloom_checkpoint_to_pytorch(
    bloom_checkpoint_path, bloom_config_file, pytorch_dump_folder_path, shard_model, pretraining_tp
):
    """Convert a TP-sharded BLOOM checkpoint folder into transformers weight files.

    Args:
        bloom_checkpoint_path: Folder holding the ``layer*`` checkpoint files.
        bloom_config_file: Path of a config JSON, or ``""`` to use ``BloomConfig()`` defaults.
        pytorch_dump_folder_path: Output folder; created when missing.
        shard_model: When truthy, write one weight file per layer file plus an index
            JSON; otherwise load everything into a ``BloomModel`` and write one file.
        pretraining_tp: Number of tensor-parallel ranks to merge per layer file.

    Raises:
        ValueError: If ``pretraining_tp`` is smaller than 1.
        AssertionError: In the non-sharded path, if the checkpoint has unexpected keys
            or leaves model keys missing.
    """
    if pretraining_tp < 1:
        raise ValueError(f"`pretraining_tp` must be at least 1, got {pretraining_tp}.")

    # Construct model
    if bloom_config_file == "":
        config = BloomConfig()
    else:
        config = BloomConfig.from_json_file(bloom_config_file)

    file_names = _list_layer_files(bloom_checkpoint_path)
    pytorch_config_dump_path = os.path.join(pytorch_dump_folder_path, CONFIG_NAME)

    if shard_model:
        # The shards are written as we go, so the output folder must exist up front.
        os.makedirs(pytorch_dump_folder_path, exist_ok=True)

        index_dict = {"weight_map": {}, "metadata": {}}
        total_size = 0
        shard_count = str(len(file_names)).zfill(5)

        for j, file in enumerate(file_names):
            print("Processing file: {}".format(file))
            tensors = _merge_tp_shards(bloom_checkpoint_path, file, pretraining_tp)

            shard_name = "pytorch_model_{}-of-{}.bin".format(str(j + 1).zfill(5), shard_count)
            torch.save(tensors, os.path.join(pytorch_dump_folder_path, shard_name))

            for key, value in tensors.items():
                total_size += value.numel() * get_dtype_size(value.dtype)
                # The first shard that contains a key owns it in the index.
                if key not in index_dict["weight_map"]:
                    index_dict["weight_map"][key] = shard_name

        index_dict["metadata"]["total_size"] = total_size
        # Write the config that was actually requested; it used to be replaced by a
        # default BloomConfig() here, silently dropping `bloom_config_file`.
        with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
            f.write(config.to_json_string())
        with open(os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME + ".index.json"), "w", encoding="utf-8") as f:
            json_config = json.dumps(index_dict, indent=2, sort_keys=True) + "\n"
            f.write(json_config)
    else:
        model = BloomModel(config)

        missing_keys = None
        for file in file_names:
            tensors = _merge_tp_shards(bloom_checkpoint_path, file, pretraining_tp)

            other_keys = model.load_state_dict(tensors, strict=False)
            assert not other_keys.unexpected_keys, f"The keys {other_keys.unexpected_keys} are unexpected"
            # A key is only truly missing if no layer file provided it.
            if missing_keys is None:
                missing_keys = set(other_keys.missing_keys)
            else:
                missing_keys = missing_keys.intersection(set(other_keys.missing_keys))

        assert not missing_keys, f"The keys {missing_keys} are missing"

        # Save pytorch-model
        os.makedirs(pytorch_dump_folder_path, exist_ok=True)
        pytorch_weights_dump_path = os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME)
        print(f"Save PyTorch model to {pytorch_weights_dump_path} with dtype {config.torch_dtype}")
        if config.torch_dtype is not None:
            model = model.to(config.torch_dtype)
        torch.save(model.state_dict(), pytorch_weights_dump_path)
        print(f"Save configuration file to {pytorch_config_dump_path}")
        with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
            f.write(config.to_json_string())
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
    # Command-line entry point: collect the conversion options and run the converter.
    cli = argparse.ArgumentParser()

    # Required parameters
    cli.add_argument(
        "--bloom_checkpoint_path",
        type=str,
        default=None,
        required=True,
        help="Path to the Megatron-LM checkpoint path.",
    )
    cli.add_argument(
        "--pytorch_dump_folder_path",
        type=str,
        default=None,
        required=True,
        help="Path to the output PyTorch model.",
    )

    # Optional parameters
    cli.add_argument(
        "--bloom_config_file",
        type=str,
        default="",
        help="An optional config json file corresponding to the pre-trained model. \nThis specifies the model architecture.",
    )
    cli.add_argument(
        "--shard_model",
        action="store_true",
        help="An optional setting to shard the output model \nThis enables sharding the converted checkpoint",
    )
    cli.add_argument(
        "--pretraining_tp",
        type=int,
        default=4,
        help="Pretraining TP rank that has been used when training the model in Megatron-LM \n",
    )

    cli_args = cli.parse_args()
    convert_bloom_checkpoint_to_pytorch(
        bloom_checkpoint_path=cli_args.bloom_checkpoint_path,
        bloom_config_file=cli_args.bloom_config_file,
        pytorch_dump_folder_path=cli_args.pytorch_dump_folder_path,
        shard_model=cli_args.shard_model,
        pretraining_tp=cli_args.pretraining_tp,
    )
|
|
|
@ -1,69 +0,0 @@
|
||||||
# from modelscope.utils.constant import Tasks
|
|
||||||
# from modelscope.pipelines import pipeline
|
|
||||||
|
|
||||||
# prompt = "输入: 中国的首都在哪里\n输出: "
|
|
||||||
|
|
||||||
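# # `task` may be 抽取 (extraction) or 分类 (classification). `text` is the text to analyse. `labels` is the list of types, separated by Chinese (full-width) commas.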
|
|
||||||
# inputs = {'task': '抽取', 'text': '杭州欢迎你。', 'labels': '地名'}
|
|
||||||
# # Keep PROMPT_TEMPLATE unchanged
|
|
||||||
# PROMPT_TEMPLATE = '输入: {text}\n{task}: {labels}\n输出: '
|
|
||||||
# # prompt = PROMPT_TEMPLATE.format(**inputs)
|
|
||||||
# pipeline_ins = pipeline(task=Tasks.text_generation, model='damo/nlp_seqgpt-560m', model_revision = 'v1.0.1', run_kwargs={'gen_token': '[GEN]'})
|
|
||||||
# print(pipeline_ins(prompt))
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import json
|
|
||||||
from modelscope.pipelines import pipeline
|
|
||||||
from modelscope.utils.constant import Tasks
|
|
||||||
|
|
||||||
from modelscope.preprocessors import TextGenerationTransformersPreprocessor
|
|
||||||
|
|
||||||
from modelscope.utils.config import Config
|
|
||||||
from modelscope import snapshot_download
|
|
||||||
from transformers import AutoConfig
|
|
||||||
|
|
||||||
from modeling_bloom import BloomForCausalLM
|
|
||||||
from tokenization_bloom_fast import BloomTokenizerFast
|
|
||||||
|
|
||||||
# Smoke test: download the SeqGPT checkpoint, build its tokenizer from the local
# tokenizer files, and print one beam-search completion.

# Fix the RNG seeds so repeated runs are comparable.
seed = 4321
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

model_dir = snapshot_download("damo/nlp_seqgpt-560m")

# The tokenizer is assembled by hand from files in the current working directory.
tokenizer_config_file = "./tokenizer_config.json"
with open(tokenizer_config_file, encoding="utf-8") as tokenizer_config_handle:
    init_kwargs = json.load(tokenizer_config_handle)
# These entries are dropped or replaced so the explicit paths below take effect.
init_kwargs.pop("tokenizer_class", None)
init_kwargs.pop("tokenizer_file", None)
init_inputs = init_kwargs.pop("init_inputs", ())
init_kwargs["vocab_file"] = None
init_kwargs["added_tokens_file"] = None
init_kwargs["special_tokens_map_file"] = "./special_tokens_map.json"
init_kwargs["tokenizer_file"] = "./tokenizer.json"
init_kwargs["name_or_path"] = model_dir
tokenizer = BloomTokenizerFast(*init_inputs, **init_kwargs)

# Load the pretrained weights directly. The script used to build a randomly
# initialised model first and throw it away, and ran generation in training mode.
# NOTE(review): assumes `from_pretrained` is a classmethod on the local BloomForCausalLM - confirm.
model = BloomForCausalLM.from_pretrained(model_dir).cuda().eval()

# The prompt text means "Input: where is the capital of the USA / Output:".
prompt = "输入: 美国的首都在哪里\n输出: "

encoded = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=1024)
input_ids = encoded.input_ids.cuda()
outputs = model.generate(input_ids, num_beams=4, do_sample=False, max_new_tokens=256)
decoded_sentences = tokenizer.batch_decode(outputs, skip_special_tokens=True)
print(decoded_sentences[0])
|
|
|
@ -1,7 +0,0 @@
|
||||||
{
|
|
||||||
"_from_model_config": true,
|
|
||||||
"bos_token_id": 1,
|
|
||||||
"eos_token_id": 2,
|
|
||||||
"pad_token_id": 3,
|
|
||||||
"transformers_version": "4.28.1"
|
|
||||||
}
|
|
File diff suppressed because it is too large
Load Diff
|
@ -1,734 +0,0 @@
|
||||||
# coding=utf-8
|
|
||||||
# Copyright 2023 HuggingFace Inc. Team and Bigscience Workshop. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
"""Flax BLOOM model."""
|
|
||||||
|
|
||||||
import math
|
|
||||||
from functools import partial
|
|
||||||
from typing import Optional, Tuple
|
|
||||||
|
|
||||||
import flax.linen as nn
|
|
||||||
import jax
|
|
||||||
import jax.numpy as jnp
|
|
||||||
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
|
|
||||||
from flax.linen import combine_masks, dot_product_attention_weights, make_causal_mask
|
|
||||||
from flax.linen.activation import tanh
|
|
||||||
from flax.traverse_util import flatten_dict, unflatten_dict
|
|
||||||
from jax import lax
|
|
||||||
|
|
||||||
from ...modeling_flax_outputs import (
|
|
||||||
FlaxBaseModelOutput,
|
|
||||||
FlaxBaseModelOutputWithPastAndCrossAttentions,
|
|
||||||
FlaxCausalLMOutput,
|
|
||||||
)
|
|
||||||
from ...modeling_flax_utils import FlaxPreTrainedModel, append_call_sample_docstring
|
|
||||||
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging
|
|
||||||
from .configuration_bloom import BloomConfig
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
|
||||||
|
|
||||||
_CHECKPOINT_FOR_DOC = "bigscience/bloom"
|
|
||||||
_CONFIG_FOR_DOC = "BloomConfig"
|
|
||||||
|
|
||||||
|
|
||||||
BLOOM_START_DOCSTRING = r"""
|
|
||||||
|
|
||||||
This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
|
|
||||||
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
|
||||||
etc.)
|
|
||||||
|
|
||||||
This model is also a Flax Linen
|
|
||||||
[flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a
|
|
||||||
regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior.
|
|
||||||
|
|
||||||
Finally, this model supports inherent JAX features such as:
|
|
||||||
|
|
||||||
- [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
|
|
||||||
- [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
|
|
||||||
- [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
|
|
||||||
- [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
config ([`BloomConfig`]): Model configuration class with all the parameters of the model.
|
|
||||||
Initializing with a config file does not load the weights associated with the model, only the
|
|
||||||
configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
|
|
||||||
dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
|
|
||||||
The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
|
|
||||||
`jax.numpy.bfloat16` (on TPUs).
|
|
||||||
|
|
||||||
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
|
|
||||||
specified all the computation will be performed with the given `dtype`.
|
|
||||||
|
|
||||||
**Note that this only specifies the dtype of the computation and does not influence the dtype of model
|
|
||||||
parameters.**
|
|
||||||
|
|
||||||
If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
|
|
||||||
[`~FlaxPreTrainedModel.to_bf16`].
|
|
||||||
"""
|
|
||||||
|
|
||||||
BLOOM_INPUTS_DOCSTRING = r"""
|
|
||||||
Args:
|
|
||||||
input_ids (`numpy.ndarray` of shape `(batch_size, input_ids_length)`):
|
|
||||||
`input_ids_length` = `sequence_length`. Indices of input sequence tokens in the vocabulary.
|
|
||||||
|
|
||||||
Indices can be obtained using [`BloomTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
|
||||||
[`PreTrainedTokenizer.__call__`] for details.
|
|
||||||
|
|
||||||
[What are input IDs?](../glossary#input-ids)
|
|
||||||
attention_mask (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
|
|
||||||
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
|
||||||
|
|
||||||
- 1 for tokens that are **not masked**,
|
|
||||||
- 0 for tokens that are **masked**.
|
|
||||||
|
|
||||||
[What are attention masks?](../glossary#attention-mask)
|
|
||||||
past_key_values (`Dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`):
|
|
||||||
Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast
|
|
||||||
auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*.
|
|
||||||
output_attentions (`bool`, *optional*):
|
|
||||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
|
||||||
tensors for more detail.
|
|
||||||
output_hidden_states (`bool`, *optional*):
|
|
||||||
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
|
||||||
more detail.
|
|
||||||
return_dict (`bool`, *optional*):
|
|
||||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
def build_alibi_tensor(attention_mask: jnp.ndarray, num_heads: int, dtype: Optional[jnp.dtype] = jnp.float32):
    """
    Flax implementation of the BLOOM Alibi tensor. BLOOM Alibi tensor is not causal as the original paper mentions, it
    relies on a translation invariance of softmax for quick implementation: with l being a tensor, and a fixed value
    `softmax(l+a) = softmax(l)`. Based on
    https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742
    Link to paper: https://arxiv.org/abs/2108.12409

    Args:
        attention_mask (`jnp.ndarray`):
            Token-wise attention mask, this should be of shape `(batch_size, max_seq_len)`.
        num_heads (`int`):
            Number of attention heads. Does not need to be a power of two.
        dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`):
            The data type (dtype) of the output tensor.

    Returns:
        Alibi tensor of shape `(batch_size, num_heads, 1, max_seq_len)`.
    """
    # Unpacking into two names also rejects masks that are not 2-D.
    _batch_size, _seq_length = attention_mask.shape

    # Slopes for the largest power-of-two number of heads that fits into `num_heads`.
    closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
    base = jnp.array(2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), dtype=jnp.float32)
    powers = jnp.arange(1, 1 + closest_power_of_2, dtype=jnp.float32)
    slopes = jax.lax.pow(base, powers)

    if closest_power_of_2 != num_heads:
        # The remaining heads take every other slope of the next power of two.
        extra_base = jnp.array(2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), dtype=jnp.float32)
        num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
        extra_powers = jnp.arange(1, 1 + 2 * num_remaining_heads, 2, dtype=jnp.float32)
        # `jax.numpy` has no `cat`; `concatenate` is the correct spelling.
        slopes = jnp.concatenate([slopes, jax.lax.pow(extra_base, extra_powers)], axis=0)

    # Note: the Alibi tensor will be added to the attention bias that is applied to the query-key product of
    # attention, so it has to broadcast against (batch_size, num_heads, query_length, key_length).
    # => here we build (batch_size, num_heads, query_length=1, key_length=max_seq_len)
    # so that the query_length dimension is broadcast.
    # This is more or less identical to T5's relative position bias:
    # https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527
    # Position of every non-masked token counted from the first non-masked one; masked positions become 0.
    arange_tensor = ((attention_mask.cumsum(axis=-1) - 1) * attention_mask)[:, None, :]
    alibi = slopes[..., None] * arange_tensor
    alibi = jnp.expand_dims(alibi, axis=2)
    return jnp.asarray(alibi, dtype)
|
|
||||||
|
|
||||||
|
|
||||||
class FlaxBloomAttention(nn.Module):
    """
    Self-attention layer with a fused query/key/value projection, an additive ALiBi bias and an optional key/value
    cache (Flax `cache` collection) for step-by-step decoding. The residual is added inside this module.
    """

    config: BloomConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        """Read the sizes from the config, validate them and create the projection / dropout sub-modules."""
        self.hidden_size = self.config.hidden_size
        self.num_heads = self.config.n_head
        self.head_dim = self.hidden_size // self.num_heads
        # True whenever the module dtype is not float32: `__call__` then asks for float32 attention weights and casts
        # them back to `self.dtype` afterwards.
        self.attention_softmax_in_fp32 = self.dtype is not jnp.float32

        # The integer division above must be exact.
        if self.head_dim * self.num_heads != self.hidden_size:
            raise ValueError(
                f"`hidden_size` must be divisible by `num_heads` (got `hidden_size`: {self.hidden_size} and "
                f"`num_heads`: {self.num_heads})."
            )

        # Both projections share dtype and kernel initializer.
        dense = partial(
            nn.Dense,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
        )

        # One projection produces query, key and value at once (3 * hidden_size features).
        self.query_key_value = dense(self.hidden_size * 3)
        self.dense = dense(self.hidden_size)
        self.resid_dropout = nn.Dropout(rate=self.config.hidden_dropout)

    def _split_heads(self, hidden_states):
        """Reshape the last axis into `(num_heads, 3 * head_dim)`; the caller then splits q/k/v on the last axis."""
        return hidden_states.reshape(hidden_states.shape[:-1] + (self.num_heads, self.head_dim * 3))

    def _merge_heads(self, hidden_states):
        """Collapse everything after the first two axes into a single axis of size `hidden_size`."""
        return hidden_states.reshape(hidden_states.shape[:2] + (self.hidden_size,))

    @nn.compact
    # Copied from transformers.models.gptj.modeling_flax_gptj.FlaxGPTJAttention._concatenate_to_cache
    def _concatenate_to_cache(self, key, value, query, attention_mask):
        """
        This function takes projected key, value states from a single input token and concatenates the states to cached
        states from previous steps. This function is slightly adapted from the official Flax repository:
        https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252

        Returns the (possibly cache-extended) `key` and `value` together with the updated `attention_mask`. On the
        first call, when the cache variables do not exist yet, they are created as zeros and the inputs are returned
        unchanged.
        """
        # detect if we're initializing by absence of existing cache data.
        is_initialized = self.has_variable("cache", "cached_key")
        cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
        cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
        cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))

        if is_initialized:
            *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
            # update key, value caches with our new 1d spatial slices
            cur_index = cache_index.value
            indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
            key = lax.dynamic_update_slice(cached_key.value, key, indices)
            value = lax.dynamic_update_slice(cached_value.value, value, indices)
            cached_key.value = key
            cached_value.value = value
            # Advance the write position by the number of query positions processed in this call.
            num_updated_cache_vectors = query.shape[1]
            cache_index.value = cache_index.value + num_updated_cache_vectors
            # causal mask for cached decoder self-attention: our single query position should only attend to those key
            # positions that have already been generated and cached, not the remaining zero elements.
            pad_mask = jnp.broadcast_to(
                jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
                tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
            )
            attention_mask = combine_masks(pad_mask, attention_mask)
        return key, value, attention_mask

    def __call__(
        self,
        hidden_states,
        residual,
        alibi,
        attention_mask=None,
        deterministic: bool = True,
        init_cache: bool = False,
        output_attentions: bool = False,
    ):
        """
        Apply self-attention to `hidden_states` and add `residual` to the result.

        Args:
            hidden_states: input activations; the first two axes are read as `(batch_size, seq_length)`.
            residual: tensor added to the projected attention output.
            alibi: bias added to the attention bias before the attention weights are computed
                (presumably the output of `build_alibi_tensor` - confirm against the caller).
            attention_mask: token mask; although it defaults to `None` it is passed to `make_causal_mask` and
                `jnp.expand_dims` unconditionally, so callers must provide it.
            deterministic: disables dropout when `True`.
            init_cache: forces the cache path so that the cache variables get created.
            output_attentions: also return the attention weights.

        Returns:
            `(attn_output,)` or `(attn_output, attn_weights)` when `output_attentions` is set.
        """
        batch_size, seq_length = hidden_states.shape[:2]

        # proj q, k, v
        fused_qkv = self.query_key_value(hidden_states)
        fused_qkv = self._split_heads(fused_qkv)
        query, key, value = jnp.split(fused_qkv, 3, axis=-1)

        causal_attention_mask = make_causal_mask(attention_mask, dtype="bool")

        # for fast decoding causal attention mask should be shifted
        causal_attention_mask_shift = (
            self.variables["cache"]["cache_index"] if self.has_variable("cache", "cached_key") else 0
        )

        # fast decoding for generate requires special attention_mask
        if self.has_variable("cache", "cached_key"):
            max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
            # Keep only the rows of the causal mask that belong to the query positions of this step.
            causal_attention_mask = jax.lax.dynamic_slice(
                causal_attention_mask,
                (0, 0, causal_attention_mask_shift, 0),
                (1, 1, seq_length, max_decoder_length),
            )

        # broadcast causal attention mask & attention mask to fit for merge
        causal_attention_mask = jnp.broadcast_to(
            causal_attention_mask, (batch_size,) + causal_attention_mask.shape[1:]
        )
        attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_attention_mask.shape)
        attention_mask = combine_masks(attention_mask, causal_attention_mask)

        # A dropout RNG is only requested when attention dropout can actually be applied.
        dropout_rng = None
        if not deterministic and self.config.attention_dropout > 0.0:
            dropout_rng = self.make_rng("dropout")

        # During fast autoregressive decoding, we feed one position at a time,
        # and cache the keys and values step by step.
        if self.has_variable("cache", "cached_key") or init_cache:
            key, value, attention_mask = self._concatenate_to_cache(key, value, query, attention_mask)

        # transform boolean mask into float mask
        # (0 where attention is allowed, the most negative finite value of `self.dtype` elsewhere)
        mask_value = jnp.finfo(self.dtype).min
        attention_bias = lax.select(
            attention_mask > 0,
            jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
            jnp.full(attention_mask.shape, mask_value).astype(self.dtype),
        )

        attention_bias = attention_bias + alibi

        # Cast in fp32 if the original dtype is different from fp32
        attention_dtype = jnp.float32 if self.attention_softmax_in_fp32 else self.dtype

        attn_weights = dot_product_attention_weights(
            query,
            key,
            bias=attention_bias,
            dropout_rng=dropout_rng,
            dropout_rate=self.config.attention_dropout,
            deterministic=deterministic,
            dtype=attention_dtype,
        )

        # Cast back in the original dtype if the native dtype is not fp32
        if self.attention_softmax_in_fp32:
            attn_weights = attn_weights.astype(self.dtype)

        # Weighted sum of the values per head, then merge the heads and project back to `hidden_size`.
        attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value)
        attn_output = self._merge_heads(attn_output)
        attn_output = self.dense(attn_output)
        attn_output = self.resid_dropout(attn_output, deterministic=deterministic)

        attn_output = attn_output + residual

        outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
        return outputs
|
|
||||||
|
|
||||||
|
|
||||||
class BloomGELU(nn.Module):
    """Activation computing `x * 0.5 * (1 + tanh(0.79788456 * x * (1 + 0.044715 * x * x)))` element-wise."""

    def setup(self):
        self.dtype = jnp.float32

    def __call__(self, x):
        # Argument of the tanh, kept in the same evaluation order as the closed-form expression above.
        tanh_argument = 0.79788456 * x * (1 + 0.044715 * x * x)
        gate = 1.0 + tanh(tanh_argument)
        return x * 0.5 * gate
|
|
||||||
|
|
||||||
|
|
||||||
class FlaxBloomMLP(nn.Module):
    """Feed-forward sub-layer: expand to `4 * hidden_size`, apply `BloomGELU`, project back, add the residual."""

    config: BloomConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        width = self.config.hidden_size
        # A single initializer instance is shared by both projections.
        shared_init = jax.nn.initializers.normal(self.config.initializer_range)

        self.dense_h_to_4h = nn.Dense(4 * width, dtype=self.dtype, kernel_init=shared_init)
        self.dense_4h_to_h = nn.Dense(width, dtype=self.dtype, kernel_init=shared_init)
        self.hidden_dropout = nn.Dropout(self.config.hidden_dropout)
        self.act = BloomGELU()

    def __call__(self, hidden_states, residual, deterministic: bool = True):
        """
        Args:
            hidden_states: input of the feed-forward layer.
            residual: tensor added to the projected output *before* dropout is applied.
            deterministic: disables dropout when `True`.
        """
        activated = self.act(self.dense_h_to_4h(hidden_states))
        with_residual = self.dense_4h_to_h(activated) + residual
        return self.hidden_dropout(with_residual, deterministic=deterministic)
|
|
||||||
|
|
||||||
|
|
||||||
class FlaxBloomBlock(nn.Module):
    """One transformer layer: layer norm -> self-attention -> layer norm -> MLP, residuals chosen by the config."""

    config: BloomConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        cfg = self.config
        eps = cfg.layer_norm_epsilon

        self.input_layernorm = nn.LayerNorm(epsilon=eps, dtype=self.dtype)
        self.self_attention = FlaxBloomAttention(cfg, dtype=self.dtype)
        self.post_attention_layernorm = nn.LayerNorm(epsilon=eps, dtype=self.dtype)
        self.mlp = FlaxBloomMLP(cfg, dtype=self.dtype)

        self.apply_residual_connection_post_layernorm = cfg.apply_residual_connection_post_layernorm
        self.hidden_dropout = cfg.hidden_dropout

    def __call__(
        self,
        hidden_states,
        alibi,
        attention_mask=None,
        deterministic: bool = True,
        init_cache: bool = False,
        output_attentions: bool = False,
    ):
        """Return `(output,)`, followed by whatever extra outputs the attention layer produced."""
        normed_input = self.input_layernorm(hidden_states)

        # The residual is taken after the layer norm only when the config asks for it.
        attn_residual = normed_input if self.apply_residual_connection_post_layernorm else hidden_states

        attn_outputs = self.self_attention(
            normed_input,
            residual=attn_residual,
            alibi=alibi,
            attention_mask=attention_mask,
            deterministic=deterministic,
            init_cache=init_cache,
            output_attentions=output_attentions,
        )
        attention_output, extra_outputs = attn_outputs[0], attn_outputs[1:]

        normed_attention = self.post_attention_layernorm(attention_output)

        # Same residual rule for the MLP sub-layer.
        mlp_residual = normed_attention if self.apply_residual_connection_post_layernorm else attention_output

        mlp_output = self.mlp(normed_attention, mlp_residual, deterministic=deterministic)

        return (mlp_output,) + extra_outputs
|
|
||||||
|
|
||||||
|
|
||||||
class FlaxBloomPreTrainedModel(FlaxPreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = BloomConfig
    base_model_prefix = "transformer"
    module_class: nn.Module = None

    def __init__(
        self,
        config: BloomConfig,
        input_shape: Tuple = (1, 1),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
        **kwargs,
    ):
        # Extra keyword arguments go to the Flax module, not to the base class.
        flax_module = self.module_class(config=config, dtype=dtype, **kwargs)
        super().__init__(config, flax_module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
        """
        Initialize the module on dummy inputs of shape `input_shape`.

        Without `params` the freshly initialized parameters are returned. With `params`, every key listed in
        `self._missing_keys` is filled in from the fresh parameters, `self._missing_keys` is reset to an empty set
        and the completed, frozen parameter tree is returned.
        """
        dummy_ids = jnp.zeros(input_shape, dtype="i4")
        dummy_mask = jnp.ones_like(dummy_ids)
        params_rng, dropout_rng = jax.random.split(rng)

        fresh_params = self.module.init(
            {"params": params_rng, "dropout": dropout_rng}, dummy_ids, dummy_mask, return_dict=False
        )["params"]

        if params is None:
            return fresh_params

        flat_fresh = flatten_dict(unfreeze(fresh_params))
        flat_given = flatten_dict(unfreeze(params))
        for missing_key in self._missing_keys:
            flat_given[missing_key] = flat_fresh[missing_key]
        self._missing_keys = set()
        return freeze(unflatten_dict(flat_given))

    def init_cache(self, batch_size, max_length):
        r"""
        Args:
            batch_size (`int`):
                batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
            max_length (`int`):
                maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
                cache.
        """
        # Dummy inputs of the full cache size; only the resulting `cache` collection is kept.
        dummy_ids = jnp.ones((batch_size, max_length), dtype="i4")
        dummy_mask = jnp.ones_like(dummy_ids)

        variables = self.module.init(
            jax.random.PRNGKey(0), dummy_ids, dummy_mask, return_dict=False, init_cache=True
        )
        return unfreeze(variables["cache"])

    @add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING)
    def __call__(
        self,
        input_ids,
        attention_mask=None,
        past_key_values: dict = None,
        params: dict = None,
        dropout_rng: jax.random.PRNGKey = None,
        train: bool = False,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        # Unset flags fall back to the values stored in the config.
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        batch_size, sequence_length = input_ids.shape

        if attention_mask is None:
            attention_mask = jnp.ones((batch_size, sequence_length))

        # Handle any PRNG if needed
        rngs = {} if dropout_rng is None else {"dropout": dropout_rng}

        variables = {"params": params or self.params}

        # If past_key_values are passed then cache is already initialized a private flag init_cache has to be passed
        # down to ensure cache is used. It has to be made sure that cache is marked as mutable so that it can be
        # changed by FlaxBloomAttention module
        mutable = False
        if past_key_values:
            variables["cache"] = past_key_values
            mutable = ["cache"]

        outputs = self.module.apply(
            variables,
            jnp.array(input_ids, dtype="i4"),
            jnp.array(attention_mask, dtype="i4"),
            not train,  # deterministic
            False,  # init_cache
            output_attentions,
            output_hidden_states,
            return_dict,
            rngs=rngs,
            mutable=mutable,
        )

        # add updated cache to model output
        if past_key_values is not None:
            outputs, mutated_variables = outputs
            updated_cache = unfreeze(mutated_variables["cache"])
            if return_dict:
                outputs["past_key_values"] = updated_cache
            else:
                outputs = outputs[:1] + (updated_cache,) + outputs[1:]

        return outputs
|
|
||||||
|
|
||||||
|
|
||||||
class FlaxBloomBlockCollection(nn.Module):
    """Stack of `config.num_hidden_layers` `FlaxBloomBlock` layers applied one after another."""

    config: BloomConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # Each block is named after its index in the stack.
        self.layers = [
            FlaxBloomBlock(self.config, name=str(index), dtype=self.dtype)
            for index in range(self.config.num_hidden_layers)
        ]

    def __call__(
        self,
        hidden_states,
        alibi,
        attention_mask=None,
        deterministic: bool = True,
        init_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
    ):
        """Return `(hidden_states, all_hidden_states, all_attentions)`; disabled entries are `None`."""
        all_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None

        for block in self.layers:
            # The input of every block is recorded, so the final output is not part of the tuple.
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            block_outputs = block(
                hidden_states,
                alibi=alibi,
                attention_mask=attention_mask,
                deterministic=deterministic,
                init_cache=init_cache,
                output_attentions=output_attentions,
            )
            hidden_states = block_outputs[0]

            if output_attentions:
                all_attentions += (block_outputs[1],)

        # this contains possible `None` values - `FlaxBloomModule` will filter them out
        return (hidden_states, all_hidden_states, all_attentions)
|
|
||||||
|
|
||||||
|
|
||||||
class FlaxBloomModule(nn.Module):
    """Word embeddings, post-embedding layer norm, the block stack and a final layer norm."""

    config: BloomConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        cfg = self.config
        eps = cfg.layer_norm_epsilon
        self.embed_dim = cfg.hidden_size

        # word embeddings (no positional embedding layer)
        self.word_embeddings = nn.Embed(
            cfg.vocab_size,
            self.embed_dim,
            embedding_init=jax.nn.initializers.normal(stddev=cfg.initializer_range),
            dtype=self.dtype,
        )
        # post-embedding layernorm
        self.word_embeddings_layernorm = nn.LayerNorm(epsilon=eps, dtype=self.dtype)
        # transformer layers
        self.h = FlaxBloomBlockCollection(cfg, dtype=self.dtype)
        # final layernorm
        self.ln_f = nn.LayerNorm(epsilon=eps, dtype=self.dtype)

    def __call__(
        self,
        input_ids=None,
        attention_mask=None,
        deterministic=True,
        init_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        """Run the full transformer and return either a tuple or a model-output object."""
        embedded = self.word_embeddings_layernorm(self.word_embeddings(input_ids))

        # The ALiBi bias depends on `attention_mask` only and is shared by all layers.
        alibi = build_alibi_tensor(attention_mask, self.config.n_head, dtype=embedded.dtype)

        stack_outputs = self.h(
            embedded,
            alibi=alibi,
            attention_mask=attention_mask,
            deterministic=deterministic,
            init_cache=init_cache,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
        )

        final_hidden = self.ln_f(stack_outputs[0])

        if output_hidden_states:
            # Append the normalized final state to the per-layer inputs collected by the stack.
            collected_states = stack_outputs[1] + (final_hidden,)
            outputs = (final_hidden, collected_states) + stack_outputs[2:]
        else:
            outputs = (final_hidden,) + stack_outputs[1:]

        if not return_dict:
            # NOTE(review): only the first and last entries are kept here, so collected hidden states are not
            # part of the tuple output - confirm this is intended.
            return tuple(v for v in [outputs[0], outputs[-1]] if v is not None)

        return FlaxBaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=final_hidden,
            hidden_states=outputs[1],
            attentions=outputs[-1],
        )
|
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings(
    "The bare Bloom Model transformer outputting raw hidden-states without any specific head on top.",
    BLOOM_START_DOCSTRING,
)
# Copied from transformers.models.gpt_neo.modeling_flax_gpt_neo.FlaxGPTNeoModel with GPTNeo->Bloom
class FlaxBloomModel(FlaxBloomPreTrainedModel):
    # Flax module class that `FlaxBloomPreTrainedModel.__init__` instantiates for this model.
    module_class = FlaxBloomModule


# Presumably attaches a usage example to `FlaxBloomModel.__call__`'s docstring - confirm against the helper.
append_call_sample_docstring(FlaxBloomModel, _CHECKPOINT_FOR_DOC, FlaxBaseModelOutput, _CONFIG_FOR_DOC)
|
|
||||||
|
|
||||||
|
|
||||||
class FlaxBloomForCausalLMModule(nn.Module):
    """`FlaxBloomModule` followed by a bias-free linear language-modeling head."""

    config: BloomConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        cfg = self.config
        self.transformer = FlaxBloomModule(cfg, dtype=self.dtype)
        self.lm_head = nn.Dense(
            cfg.vocab_size,
            use_bias=False,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(stddev=cfg.initializer_range),
        )

    def __call__(
        self,
        input_ids,
        attention_mask,
        deterministic: bool = True,
        init_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        """Return the logits plus the remaining transformer outputs (tuple or `FlaxCausalLMOutput`)."""
        outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            deterministic=deterministic,
            init_cache=init_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        last_hidden = outputs[0]

        if not self.config.tie_word_embeddings:
            lm_logits = self.lm_head(last_hidden)
        else:
            # Tied weights: reuse the transposed embedding matrix as the kernel of the head.
            embedding_matrix = self.transformer.variables["params"]["word_embeddings"]["embedding"]
            shared_kernel = embedding_matrix.T
            lm_logits = self.lm_head.apply({"params": {"kernel": shared_kernel}}, last_hidden)

        if return_dict:
            return FlaxCausalLMOutput(
                logits=lm_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
            )

        return (lm_logits,) + outputs[1:]
|
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings(
|
|
||||||
"""
|
|
||||||
The Bloom Model transformer with a language modeling head on top (linear layer with weights tied to the input
|
|
||||||
embeddings).
|
|
||||||
""",
|
|
||||||
BLOOM_START_DOCSTRING,
|
|
||||||
)
|
|
||||||
class FlaxBloomForCausalLM(FlaxBloomPreTrainedModel):
    module_class = FlaxBloomForCausalLMModule

    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
        """Create the decoding cache and a static attention mask of length `max_length`."""
        # `input_ids` must be 2-D; only the batch size is needed here.
        batch_size, _ = input_ids.shape

        # initializing the cache
        cache = self.init_cache(batch_size, max_length)

        # Note that usually one would have to put 0's in the attention_mask for
        # x > input_ids.shape[-1] and x < cache_length. But since Bloom uses a causal mask,
        # those positions are masked anyway. Thus, we can create a single static attention_mask here,
        # which is more efficient for compilation
        full_mask = jnp.ones((batch_size, max_length), dtype="i4")
        if attention_mask is not None:
            # Copy the caller's mask into the leading positions of the all-ones mask.
            full_mask = lax.dynamic_update_slice(full_mask, attention_mask, (0, 0))

        return {
            "past_key_values": cache,
            "attention_mask": full_mask,
        }

    def update_inputs_for_generation(self, model_outputs, model_kwargs):
        """Carry the cache returned by the model over to the next generation step."""
        model_kwargs["past_key_values"] = model_outputs.past_key_values
        return model_kwargs


append_call_sample_docstring(FlaxBloomForCausalLM, _CHECKPOINT_FOR_DOC, FlaxCausalLMOutput, _CONFIG_FOR_DOC)
|
|
|
@ -1,9 +0,0 @@
|
||||||
{
|
|
||||||
"additional_special_tokens": [
|
|
||||||
"[GEN]"
|
|
||||||
],
|
|
||||||
"bos_token": "<s>",
|
|
||||||
"eos_token": "</s>",
|
|
||||||
"pad_token": "<pad>",
|
|
||||||
"unk_token": "<unk>"
|
|
||||||
}
|
|
|
@ -1,177 +0,0 @@
|
||||||
# coding=utf-8
|
|
||||||
# Copyright 2022 The HuggingFace Inc. team.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
"""Tokenization classes for Bloom."""
|
|
||||||
|
|
||||||
|
|
||||||
import pickle
|
|
||||||
from typing import Optional, Tuple
|
|
||||||
|
|
||||||
from transformers.tokenization_utils_base import BatchEncoding
|
|
||||||
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
|
|
||||||
from transformers.utils import logging
|
|
||||||
|
|
||||||
|
|
||||||
# Logger named after this module.
logger = logging.get_logger(__name__)

# File name of the serialized tokenizer, keyed by file role.
VOCAB_FILES_NAMES = {"tokenizer_file": "tokenizer.json"}

# URL of the tokenizer file for each known checkpoint, keyed by file role and then by checkpoint name.
PRETRAINED_VOCAB_FILES_MAP = {
    "tokenizer_file": {
        "bigscience/tokenizer": "https://huggingface.co/bigscience/tokenizer/blob/main/tokenizer.json",
        "bigscience/bloom-560m": "https://huggingface.co/bigscience/bloom-560m/blob/main/tokenizer.json",
        "bigscience/bloom-1b1": "https://huggingface.co/bigscience/bloom-1b1/blob/main/tokenizer.json",
        "bigscience/bloom-1b7": "https://huggingface.co/bigscience/bloom-1b7/blob/main/tokenizer.json",
        "bigscience/bloom-3b": "https://huggingface.co/bigscience/bloom-3b/blob/main/tokenizer.json",
        "bigscience/bloom-7b1": "https://huggingface.co/bigscience/bloom-7b1/blob/main/tokenizer.json",
        "bigscience/bloom": "https://huggingface.co/bigscience/bloom/blob/main/tokenizer.json",
    },
}
|
|
||||||
|
|
||||||
|
|
||||||
class BloomTokenizerFast(PreTrainedTokenizerFast):
|
|
||||||
"""
|
|
||||||
Construct a "fast" Bloom tokenizer (backed by HuggingFace's *tokenizers* library). Based on byte-level
|
|
||||||
Byte-Pair-Encoding.
|
|
||||||
|
|
||||||
This tokenizer has been trained to treat spaces like parts of the tokens (a bit like sentencepiece) so a word will
|
|
||||||
be encoded differently whether it is at the beginning of the sentence (without space) or not:
|
|
||||||
|
|
||||||
```python
|
|
||||||
>>> from transformers import BloomTokenizerFast
|
|
||||||
|
|
||||||
>>> tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom")
|
|
||||||
>>> tokenizer("Hello world")["input_ids"]
|
|
||||||
[59414, 8876]
|
|
||||||
|
|
||||||
>>> tokenizer(" Hello world")["input_ids"]
|
|
||||||
[86153, 8876]
|
|
||||||
```
|
|
||||||
|
|
||||||
You can get around that behavior by passing `add_prefix_space=True` when instantiating this tokenizer, but since
|
|
||||||
the model was not pretrained this way, it might yield a decrease in performance.
|
|
||||||
|
|
||||||
<Tip>
|
|
||||||
|
|
||||||
When used with `is_split_into_words=True`, this tokenizer needs to be instantiated with `add_prefix_space=True`.
|
|
||||||
|
|
||||||
</Tip>
|
|
||||||
|
|
||||||
This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
|
|
||||||
refer to this superclass for more information regarding those methods.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
vocab_file (`str`):
|
|
||||||
Path to the vocabulary file.
|
|
||||||
merges_file (`str`):
|
|
||||||
Path to the merges file.
|
|
||||||
errors (`str`, *optional*, defaults to `"replace"`):
|
|
||||||
Paradigm to follow when decoding bytes to UTF-8. See
|
|
||||||
[bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information.
|
|
||||||
unk_token (`str`, *optional*, defaults to `<unk>`):
|
|
||||||
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
|
|
||||||
token instead.
|
|
||||||
bos_token (`str`, *optional*, defaults to `<s>`):
|
|
||||||
The beginning of sequence token.
|
|
||||||
eos_token (`str`, *optional*, defaults to `</s>`):
|
|
||||||
The end of sequence token.
|
|
||||||
add_prefix_space (`bool`, *optional*, defaults to `False`):
|
|
||||||
Whether or not to add an initial space to the input. This allows to treat the leading word just as any
|
|
||||||
other word. (Bloom tokenizer detect beginning of words by the preceding space).
|
|
||||||
trim_offsets (`bool`, *optional*, defaults to `True`):
|
|
||||||
Whether or not the post-processing step should trim offsets to avoid including whitespaces.
|
|
||||||
"""
|
|
||||||
|
|
||||||
vocab_files_names = VOCAB_FILES_NAMES
|
|
||||||
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
|
|
||||||
model_input_names = ["input_ids", "attention_mask"]
|
|
||||||
slow_tokenizer_class = None
|
|
||||||
# No `max_model_input_sizes` as BLOOM uses ALiBi positional embeddings
|
|
||||||
|
|
||||||
def __init__(
    self,
    vocab_file=None,
    merges_file=None,
    tokenizer_file=None,
    unk_token="<unk>",
    bos_token="<s>",
    eos_token="</s>",
    pad_token="<pad>",
    add_prefix_space=False,
    clean_up_tokenization_spaces=False,
    **kwargs,
):
    """Initialise the parent tokenizer, then apply `add_prefix_space` to the backend.

    Every argument is forwarded to the parent constructor (`vocab_file` and
    `merges_file` positionally, the rest by keyword). Afterwards the backend's
    pre-tokenizer and decoder are round-tripped through `pickle`; when
    `add_prefix_space` is truthy, the serialized bytes are patched so that the
    `"add_prefix_space":false` marker becomes `"add_prefix_space": true`
    before they are loaded back.

    The value of `add_prefix_space` is also stored on the instance, where the
    encode methods consult it.
    """
    super().__init__(
        vocab_file,
        merges_file,
        tokenizer_file=tokenizer_file,
        unk_token=unk_token,
        bos_token=bos_token,
        eos_token=eos_token,
        pad_token=pad_token,
        add_prefix_space=add_prefix_space,
        clean_up_tokenization_spaces=clean_up_tokenization_spaces,
        **kwargs,
    )

    # Byte markers swapped inside the pickled component state.
    # NOTE(review): presumably the pickled state embeds the component's JSON
    # serialization -- confirm against the backend library.
    flag_off = b'"add_prefix_space":false'
    flag_on = b'"add_prefix_space": true'

    # TODO @ArthurZucker this can only work one way for now, to update later-on. Tests should also properly
    # check this as they were green before.
    pre_tokenizer_blob = pickle.dumps(self.backend_tokenizer.pre_tokenizer)
    decoder_blob = pickle.dumps(self.backend_tokenizer.decoder)

    if add_prefix_space:
        # Only the false -> true direction is handled (see TODO above).
        pre_tokenizer_blob = pre_tokenizer_blob.replace(flag_off, flag_on)
        decoder_blob = decoder_blob.replace(flag_off, flag_on)

    # Reinstall both components from their (possibly patched) serialized form.
    self.backend_tokenizer.pre_tokenizer = pickle.loads(pre_tokenizer_blob)
    self.backend_tokenizer.decoder = pickle.loads(decoder_blob)

    self.add_prefix_space = add_prefix_space
|
|
||||||
|
|
||||||
def _batch_encode_plus(self, *args, **kwargs) -> BatchEncoding:
    """Guard batch encoding of pre-tokenized input, then defer to the parent class.

    `is_split_into_words` is looked up in the keyword arguments only (it
    defaults to `False` when absent). All arguments are forwarded unchanged.

    Raises:
        Exception: if `is_split_into_words` is truthy while
            `self.add_prefix_space` is falsy.
    """
    pretokenized = kwargs.get("is_split_into_words", False)
    if not self.add_prefix_space and pretokenized:
        message = (
            f"You need to instantiate {self.__class__.__name__} with add_prefix_space=True to use it with"
            " pretokenized inputs."
        )
        raise Exception(message)

    return super()._batch_encode_plus(*args, **kwargs)
|
|
||||||
|
|
||||||
def _encode_plus(self, *args, **kwargs) -> BatchEncoding:
    """Guard single-input encoding of pre-tokenized input, then defer to the parent class.

    `is_split_into_words` is looked up in the keyword arguments only (it
    defaults to `False` when absent). All arguments are forwarded unchanged.

    Raises:
        Exception: if `is_split_into_words` is truthy while
            `self.add_prefix_space` is falsy.
    """
    pretokenized = kwargs.get("is_split_into_words", False)
    if not self.add_prefix_space and pretokenized:
        message = (
            f"You need to instantiate {self.__class__.__name__} with add_prefix_space=True to use it with"
            " pretokenized inputs."
        )
        raise Exception(message)

    return super()._encode_plus(*args, **kwargs)
|
|
||||||
|
|
||||||
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
    """Save the backend model's files and return what the backend reports, as a tuple.

    Args:
        save_directory: Passed to the backend model's `save` as its first argument.
        filename_prefix: Passed to the backend model's `save` as `name`.

    Returns:
        A tuple built from the iterable returned by the backend `save` call.
    """
    return tuple(self._tokenizer.model.save(save_directory, name=filename_prefix))
|
|
||||||
|
|
||||||
@property
# Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.default_chat_template
def default_chat_template(self):
    """
    Fallback chat template: emits each message's content followed by the EOS token, without
    referencing the message role. Accessing it logs a notice via `logger.warning_once`.
    """
    notice = (
        "\nNo chat template is defined for this tokenizer - using the default template "
        f"for the {self.__class__.__name__} class. If the default is not appropriate for "
        "your model, please set `tokenizer.chat_template` to an appropriate template. "
        "See https://huggingface.co/docs/transformers/main/chat_templating for more information.\n"
    )
    logger.warning_once(notice)

    # Jinja-style template, split into loop open / body / loop close.
    template_parts = (
        "{% for message in messages %}",
        "{{ message.content }}{{ eos_token }}",
        "{% endfor %}",
    )
    return "".join(template_parts)
|
|
501215
seqgpt/tokenizer.json
501215
seqgpt/tokenizer.json
File diff suppressed because it is too large
Load Diff
|
@ -1,11 +0,0 @@
|
||||||
{
|
|
||||||
"add_prefix_space": false,
|
|
||||||
"bos_token": "<s>",
|
|
||||||
"clean_up_tokenization_spaces": false,
|
|
||||||
"eos_token": "</s>",
|
|
||||||
"model_max_length": 1000000000000000019884624838656,
|
|
||||||
"pad_token": "<pad>",
|
|
||||||
"padding_side": "left",
|
|
||||||
"tokenizer_class": "BloomTokenizer",
|
|
||||||
"unk_token": "<unk>"
|
|
||||||
}
|
|
Binary file not shown.
Loading…
Reference in New Issue