add seqgpt and prompt_clue
This commit is contained in:
parent
65578680cf
commit
0dd2f2bab4
|
@ -730,7 +730,7 @@ class T5Block(nn.Module):
|
|||
return outputs
|
||||
|
||||
|
||||
class T5PreTrainedModel(TorchModel, PreTrainedModel):
|
||||
class T5PreTrainedModel(PreTrainedModel):
|
||||
"""
|
||||
An abstract class to handle weights initialization and a simple interface
|
||||
for downloading and loading pretrained models.
|
||||
|
@ -743,8 +743,7 @@ class T5PreTrainedModel(TorchModel, PreTrainedModel):
|
|||
supports_gradient_checkpointing = True
|
||||
|
||||
def __init__(self, config, **kwargs):
|
||||
super().__init__(config.name_or_path, **kwargs)
|
||||
super(Model, self).__init__(config)
|
||||
super().__init__(config, **kwargs)
|
||||
|
||||
@property
|
||||
def dummy_inputs(self):
|
||||
|
|
|
@ -0,0 +1,30 @@
|
|||
{
|
||||
"architectures": [
|
||||
"T5ForConditionalGeneration"
|
||||
],
|
||||
"d_ff": 2048,
|
||||
"d_kv": 64,
|
||||
"d_model": 768,
|
||||
"decoder_start_token_id": 0,
|
||||
"dense_act_fn": "gelu_new",
|
||||
"dropout_rate": 0.1,
|
||||
"eos_token_id": 1,
|
||||
"feed_forward_proj": "gated-gelu",
|
||||
"initializer_factor": 1.0,
|
||||
"is_encoder_decoder": true,
|
||||
"is_gated_act": true,
|
||||
"layer_norm_epsilon": 1e-06,
|
||||
"model_type": "t5",
|
||||
"num_decoder_layers": 12,
|
||||
"num_heads": 12,
|
||||
"num_layers": 12,
|
||||
"output_past": true,
|
||||
"pad_token_id": 0,
|
||||
"relative_attention_max_distance": 128,
|
||||
"relative_attention_num_buckets": 32,
|
||||
"tie_word_embeddings": false,
|
||||
"torch_dtype": "float32",
|
||||
"transformers_version": "4.26.0.dev0",
|
||||
"use_cache": true,
|
||||
"vocab_size": 32128
|
||||
}
|
|
@ -0,0 +1,11 @@
|
|||
{
|
||||
"framework": "pytorch",
|
||||
"task": "text2text-generation",
|
||||
"model": {
|
||||
"type": "T5",
|
||||
"language": "zh"
|
||||
},
|
||||
"pipeline": {
|
||||
"type": "text2text-generation"
|
||||
}
|
||||
}
|
|
@ -1,21 +1,51 @@
|
|||
import torch
|
||||
from modelscope.pipelines import pipeline
|
||||
from modelscope.utils.constant import Tasks
|
||||
|
||||
# from modelscope.models.nlp import T5ForConditionalGeneration
|
||||
from modelscope.preprocessors import TextGenerationTransformersPreprocessor
|
||||
|
||||
from modeling_t5 import T5ForConditionalGeneration
|
||||
from modelscope.utils.config import Config
|
||||
from configuration import T5Config
|
||||
from modelscope import snapshot_download
|
||||
from transformers import AutoConfig
|
||||
|
||||
|
||||
model = T5ForConditionalGeneration.from_pretrained(
|
||||
"ClueAI/PromptCLUE-base-v1-5", revision="v0.1"
|
||||
seed = 4321
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
model_dir = snapshot_download("ClueAI/PromptCLUE-base-v1-5")
|
||||
|
||||
# config = T5Config()
|
||||
|
||||
config, kwargs = AutoConfig.from_pretrained(
|
||||
model_dir,
|
||||
return_unused_kwargs=True,
|
||||
trust_remote_code=True,
|
||||
code_revision=None,
|
||||
_commit_hash=None,
|
||||
)
|
||||
preprocessor = TextGenerationTransformersPreprocessor(model.model_dir)
|
||||
|
||||
|
||||
model = T5ForConditionalGeneration(config)
|
||||
model = model.from_pretrained(model_dir)
|
||||
|
||||
preprocessor = TextGenerationTransformersPreprocessor(model_dir)
|
||||
|
||||
out = preprocessor._tokenize_text("生成与下列文字相同意")
|
||||
|
||||
tokenizer = preprocessor.nlp_tokenizer
|
||||
response, history = model.chat(tokenizer, "生成与下列文字相同意", history=[])
|
||||
|
||||
# model = T5ForConditionalGeneration.from_pretrained(
|
||||
# "ClueAI/PromptCLUE-base-v1-5", revision="v0.1"
|
||||
# )
|
||||
|
||||
pipeline_t2t = pipeline(
|
||||
task=Tasks.text2text_generation, model=model, preprocessor=preprocessor
|
||||
)
|
||||
|
||||
|
||||
print(pipeline_t2t("生成与下列文字相同意思的句子:\n白云遍地无人扫\n答案:", do_sample=True, top_p=0.8))
|
||||
# {'text': '白云散去无踪,没人扫。'}
|
||||
|
||||
|
|
|
@ -14,7 +14,8 @@
|
|||
# limitations under the License.
|
||||
import copy
|
||||
import warnings
|
||||
from typing import Optional, Tuple, Union
|
||||
from typing import Optional, Tuple, Union, List, Dict, Any
|
||||
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
@ -145,148 +146,6 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
|
|||
return_dict: Optional[bool] = None,
|
||||
**kwargs,
|
||||
) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
|
||||
r"""
|
||||
Args:
|
||||
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||||
Indices of input sequence tokens in the vocabulary. T5 is a model
|
||||
with relative position embeddings so you should be able to pad the
|
||||
inputs on both the right and the left.
|
||||
|
||||
Indices can be obtained using [`T5Tokenizer`]. See
|
||||
[`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`]
|
||||
for detail.
|
||||
|
||||
[What are input IDs?](../glossary#input-ids)
|
||||
|
||||
To know more on how to prepare `input_ids` for pretraining take a
|
||||
look a [T5 Training](./t5#training).
|
||||
attention_mask (`torch.FloatTensor` of shape `(batch_size,sequence_length)`, *optional*):
|
||||
Mask to avoid performing attention on padding token indices. Mask
|
||||
values selected in `[0, 1]`:
|
||||
|
||||
- 1 for tokens that are **not masked**,
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
[What are attention masks?](../glossary#attention-mask)
|
||||
decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
|
||||
Indices of decoder input sequence tokens in the vocabulary.
|
||||
|
||||
Indices can be obtained using [`T5Tokenizer`]. See
|
||||
[`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`]
|
||||
for details.
|
||||
|
||||
[What are decoder input IDs?](../glossary#decoder-input-ids)
|
||||
|
||||
T5 uses the `pad_token_id` as the starting token for
|
||||
`decoder_input_ids` generation. If `past_key_values` is used,
|
||||
optionally only the last `decoder_input_ids` have to be input (see
|
||||
`past_key_values`).
|
||||
|
||||
To know more on how to prepare `decoder_input_ids` for pretraining
|
||||
take a look at [T5 Training](./t5#training).
|
||||
decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
|
||||
Default behavior: generate a tensor that ignores pad tokens in
|
||||
`decoder_input_ids`. Causal mask will also be used by default.
|
||||
head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
|
||||
Mask to nullify selected heads of the self-attention modules in the
|
||||
encoder. Mask values selected in `[0, 1]`:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
decoder_head_mask (`torch.FloatTensor` of shape `(num_heads,)` or
|
||||
`(num_layers, num_heads)`, *optional*):
|
||||
Mask to nullify selected heads of the self-attention modules in the
|
||||
decoder. Mask values selected in `[0, 1]`:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
cross_attn_head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
|
||||
Mask to nullify selected heads of the cross-attention modules in
|
||||
the decoder. Mask values selected in `[0, 1]`:
|
||||
|
||||
- 1 indicates the head is **not masked**,
|
||||
- 0 indicates the head is **masked**.
|
||||
|
||||
encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*):
|
||||
Tuple consists of (`last_hidden_state`, `optional`: *hidden_states*,
|
||||
`optional`: *attentions*) `last_hidden_state` of shape `(batch_size,
|
||||
sequence_length, hidden_size)` is a sequence of hidden states at the
|
||||
output of the last layer of the encoder. Used in the cross-attention
|
||||
of the decoder.
|
||||
past_key_values (`tuple(tuple(torch.FloatTensor))` of length
|
||||
`config.n_layers` with each tuple having 4 tensors of shape
|
||||
`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
||||
|
||||
Contains precomputed key and value hidden states of the attention
|
||||
blocks. Can be used to speed up decoding.
|
||||
|
||||
If `past_key_values` are used, the user can optionally input only
|
||||
the last `decoder_input_ids` (those that don't have their past key
|
||||
value states given to this model) of shape `(batch_size, 1)` instead
|
||||
of all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
|
||||
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
||||
Optionally, instead of passing `input_ids` you can choose to
|
||||
directly pass an embedded representation. This is useful if you want
|
||||
more control over how to convert `input_ids` indices into associated
|
||||
vectors than the model's internal embedding lookup matrix.
|
||||
decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`,
|
||||
*optional*):
|
||||
Optionally, instead of passing `decoder_input_ids` you can choose to
|
||||
directly pass an embedded representation. If `past_key_values` is
|
||||
used, optionally only the last `decoder_inputs_embeds` have to be
|
||||
input (see `past_key_values`). This is useful if you want more
|
||||
control over how to convert `decoder_input_ids` indices into
|
||||
associated vectors than the model's internal embedding lookup
|
||||
matrix.
|
||||
|
||||
If `decoder_input_ids` and `decoder_inputs_embeds` are both unset,
|
||||
`decoder_inputs_embeds` takes the value of `inputs_embeds`.
|
||||
|
||||
use_cache (`bool`, *optional*):
|
||||
If set to `True`, `past_key_values` key value states are returned
|
||||
and can be used to speed up decoding (see `past_key_values`).
|
||||
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention
|
||||
layers. See `attentions` under returned tensors for more detail.
|
||||
output_hidden_states (`bool`, *optional*):
|
||||
Whether or not to return the hidden states of all layers. See
|
||||
`hidden_states` under returned tensors for more detail.
|
||||
return_dict (`bool`, *optional*):
|
||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain
|
||||
tuple.
|
||||
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
||||
Labels for computing the sequence classification/regression loss.
|
||||
Indices should be in `[-100, 0, ..., config.vocab_size - 1]`. All
|
||||
labels set to `-100` are ignored (masked), the loss is only computed
|
||||
for labels in `[0, ..., config.vocab_size]`
|
||||
|
||||
Returns:
|
||||
|
||||
Examples:
|
||||
|
||||
>>> from transformers import T5Tokenizer, T5ForConditionalGeneration
|
||||
|
||||
>>> tokenizer = T5Tokenizer.from_pretrained("t5-small")
|
||||
>>> model = T5ForConditionalGeneration.from_pretrained("t5-small")
|
||||
|
||||
>>> # training
|
||||
>>> input_ids = tokenizer("The <extra_id_0> walks in <extra_id_1> park", return_tensors="pt").input_ids
|
||||
>>> labels = tokenizer("<extra_id_0> cute dog <extra_id_1> the <extra_id_2>", return_tensors="pt").input_ids
|
||||
>>> outputs = model(input_ids=input_ids, labels=labels)
|
||||
>>> loss = outputs.loss
|
||||
>>> logits = outputs.logits
|
||||
|
||||
>>> # inference
|
||||
>>> input_ids = tokenizer(
|
||||
... "summarize: studies have shown that owning a dog is good for you", return_tensors="pt"
|
||||
>>> ).input_ids # Batch size 1
|
||||
>>> outputs = model.generate(input_ids)
|
||||
>>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
|
||||
>>> # studies have shown that owning a dog is good for you.
|
||||
"""
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
return_dict = (
|
||||
return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
@ -440,6 +299,82 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
|
|||
sequences=output if isinstance(output, torch.Tensor) else output[0]
|
||||
)
|
||||
|
||||
def chat(
    self,
    tokenizer,
    query: str,
    history: Optional[List[Dict[str, str]]] = None,
    role: str = "user",
):
    """Generate a reply to ``query`` and record the query in ``history``.

    Args:
        tokenizer: Callable returning a mapping with an ``"input_ids"`` entry
            for ``query``; must also provide ``decode``.
        query: The prompt text to send to the model.
        history: Conversation log of ``{"role": ..., "content": ...}`` dicts.
            A new list is created when ``None``; a list passed in is mutated
            in place.
        role: Role name stored alongside ``query`` in the history entry.

    Returns:
        A ``(response, history)`` tuple. ``response`` is the decoded first
        sequence returned by ``self.sample``. Because ``sample`` appends new
        tokens to its input, the decoded ids include the prompt tokens.
    """
    if history is None:
        history = []

    token = tokenizer(query)
    # Wrap the ids in a list to add a batch dimension, then move the batch
    # onto whatever device the model's parameters are on.
    inputs_tensor = torch.as_tensor([token["input_ids"]]).to(
        next(self.parameters()).device
    )

    # Work on a copy so the model's own generation config is never touched.
    generation_config = copy.deepcopy(self.generation_config)
    input_ids = inputs_tensor.repeat_interleave(
        generation_config.num_return_sequences, dim=0
    )

    outputs = self.sample(
        input_ids,
        generation_config.pad_token_id,
        generation_config.eos_token_id,
        generation_config.output_hidden_states,
        tokenizer,
    )

    # Only the first returned sequence is decoded.
    response = tokenizer.decode(outputs.tolist()[0])
    # NOTE(review): only the query is logged; the response is not appended
    # to the history -- confirm this is intended.
    history.append({"role": role, "content": query})
    return response, history
|
||||
|
||||
def sample(
    self,
    input_ids: torch.LongTensor,
    pad_token_id: Optional[int] = None,
    eos_token_id: Optional[Union[int, List[int]]] = None,
    output_hidden_states: Optional[bool] = None,
    tokenizer=None,
    max_new_tokens: Optional[int] = None,
):
    """Extend ``input_ids`` one token at a time until every row hits EOS.

    Each step calls ``self(input_ids)``, which is expected to return a pair
    whose second element holds the next token id for every row.

    Args:
        input_ids: Prompt token ids, one row per sequence.
        pad_token_id: Id emitted for rows that have already finished.
        eos_token_id: A single id or a list of ids that end a sequence.
        output_hidden_states: Accepted for interface compatibility; unused.
        tokenizer: Accepted for interface compatibility; unused.
        max_new_tokens: Optional cap on the number of appended tokens.
            ``None`` (the default) keeps generating until every row finished.

    Returns:
        ``input_ids`` with the generated tokens appended along the last
        dimension. The step on which the last row finishes is not appended.

    Raises:
        ValueError: If ``pad_token_id`` or ``eos_token_id`` is ``None``.
    """
    if pad_token_id is None or eos_token_id is None:
        raise ValueError(
            "`pad_token_id` and `eos_token_id` must both be set for sampling."
        )
    if isinstance(eos_token_id, int):
        eos_token_id = [eos_token_id]
    eos_token_id_tensor = torch.tensor(eos_token_id, device=input_ids.device)

    # 1 for rows that have already produced an EOS token, 0 otherwise.
    is_finished = torch.zeros(
        input_ids.shape[0], dtype=torch.long, device=input_ids.device
    )
    generated = 0
    while max_new_tokens is None or generated < max_new_tokens:
        _, next_tokens = self(input_ids)

        # Finished rows keep receiving the padding token.
        next_tokens = next_tokens * (1 - is_finished) + pad_token_id * is_finished

        # Compare every row against every EOS id; a positional `eq` would be
        # wrong (or fail to broadcast) when several EOS ids are configured.
        hit_eos = (next_tokens[:, None] == eos_token_id_tensor).any(dim=-1)
        is_finished = is_finished | hit_eos.long()
        if is_finished.min() == 1:  # every row in the batch is done
            break

        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        generated += 1
    return input_ids
|
||||
|
||||
def _reorder_cache(self, past, beam_idx):
|
||||
# if decoder past is not included in output
|
||||
# speedy decoding is disabled and no need to reorder
|
||||
|
|
|
@ -0,0 +1,33 @@
|
|||
{
|
||||
"_name_or_path": "./bloomz_560m_pretrained",
|
||||
"apply_residual_connection_post_layernorm": false,
|
||||
"architectures": [
|
||||
"BloomForCausalLM"
|
||||
],
|
||||
"attention_dropout": 0.0,
|
||||
"attention_softmax_in_fp32": true,
|
||||
"bias_dropout_fusion": true,
|
||||
"bos_token_id": 1,
|
||||
"eos_token_id": 2,
|
||||
"hidden_dropout": 0.0,
|
||||
"hidden_size": 1024,
|
||||
"initializer_range": 0.02,
|
||||
"layer_norm_epsilon": 1e-05,
|
||||
"masked_softmax_fusion": true,
|
||||
"model_type": "bloom",
|
||||
"n_head": 16,
|
||||
"n_inner": null,
|
||||
"n_layer": 24,
|
||||
"offset_alibi": 100,
|
||||
"pad_token_id": 3,
|
||||
"pretraining_tp": 1,
|
||||
"seq_length": 2048,
|
||||
"skip_bias_add": true,
|
||||
"skip_bias_add_qkv": false,
|
||||
"slow_but_exact": false,
|
||||
"torch_dtype": "float16",
|
||||
"transformers_version": "4.28.1",
|
||||
"unk_token_id": 0,
|
||||
"use_cache": true,
|
||||
"vocab_size": 250880
|
||||
}
|
|
@ -0,0 +1,11 @@
|
|||
{
|
||||
"framework": "pytorch",
|
||||
"task": "text-generation",
|
||||
"model": {
|
||||
"type": "bloom"
|
||||
},
|
||||
"pipeline": {
|
||||
"type": "seqgpt"
|
||||
},
|
||||
"allow_remote": true
|
||||
}
|
|
@ -0,0 +1,242 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2022 the Big Science Workshop and HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Bloom configuration"""
|
||||
from collections import OrderedDict
|
||||
from typing import TYPE_CHECKING, Any, List, Mapping, Optional
|
||||
|
||||
from packaging import version
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ... import PreTrainedTokenizer, TensorType
|
||||
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from transformers.onnx import OnnxConfigWithPast, PatchingSpec
|
||||
from transformers.utils import is_torch_available, logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
# Maps each released BLOOM checkpoint name to the URL of its `config.json`.
# Every entry uses the `/resolve/main/` endpoint, which serves the raw file;
# `/blob/main/` points at the HTML viewer page instead.
BLOOM_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "bigscience/bloom": "https://huggingface.co/bigscience/bloom/resolve/main/config.json",
    "bigscience/bloom-560m": "https://huggingface.co/bigscience/bloom-560m/resolve/main/config.json",
    "bigscience/bloom-1b1": "https://huggingface.co/bigscience/bloom-1b1/resolve/main/config.json",
    "bigscience/bloom-1b7": "https://huggingface.co/bigscience/bloom-1b7/resolve/main/config.json",
    "bigscience/bloom-3b": "https://huggingface.co/bigscience/bloom-3b/resolve/main/config.json",
    "bigscience/bloom-7b1": "https://huggingface.co/bigscience/bloom-7b1/resolve/main/config.json",
}
|
||||
|
||||
|
||||
class BloomConfig(PretrainedConfig):
    """
    This is the configuration class to store the configuration of a [`BloomModel`]. It is used to instantiate a Bloom
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to the Bloom architecture
    [bigscience/bloom](https://huggingface.co/bigscience/bloom).

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        vocab_size (`int`, *optional*, defaults to 250880):
            Vocabulary size of the Bloom model. Defines the maximum number of different tokens that can be represented
            by the `inputs_ids` passed when calling [`BloomModel`]. Check [this
            discussion](https://huggingface.co/bigscience/bloom/discussions/120#633d28389addb8530b406c2a) on how the
            `vocab_size` has been defined.
        hidden_size (`int`, *optional*, defaults to 64):
            Dimensionality of the embeddings and hidden states. The legacy `n_embed` keyword argument, when given,
            takes precedence over this value.
        n_layer (`int`, *optional*, defaults to 2):
            Number of hidden layers in the Transformer.
        n_head (`int`, *optional*, defaults to 8):
            Number of attention heads for each attention layer in the Transformer.
        layer_norm_epsilon (`float`, *optional*, defaults to 1e-5):
            The epsilon to use in the layer normalization layers.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models).
        bos_token_id (`int`, *optional*, defaults to 1):
            Id of the beginning-of-sequence token.
        eos_token_id (`int`, *optional*, defaults to 2):
            Id of the end-of-sequence token.
        apply_residual_connection_post_layernorm (`bool`, *optional*, defaults to `False`):
            If enabled, use the layer norm of the hidden states as the residual in the transformer blocks
        hidden_dropout (`float`, *optional*, defaults to 0.0):
            Dropout rate of the dropout function on the bias dropout.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            Dropout rate applied to the attention probs
        pretraining_tp (`int`, *optional*, defaults to `1`):
            Experimental feature. Tensor parallelism rank used during pretraining with Megatron. Please refer to [this
            document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is
            necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
            issue](https://github.com/pytorch/pytorch/issues/76232). Note also that this is enabled only when
            `slow_but_exact=True`.
        slow_but_exact (`bool`, *optional*, defaults to `False`):
            Experimental feature. Whether to use slow but exact implementation of the attention mechanism. While
            merging the TP rank tensors, due to slicing operations the results may be slightly different between the
            model trained on Megatron and our model. Please refer to [this
            issue](https://github.com/pytorch/pytorch/issues/76232). A solution to obtain more accurate results is to
            enable this feature. Enabling this will hurt the computational time of the inference. Will be probably
            resolved in the future once the main model has been fine-tuned with TP_rank=1.

    Example:

    ```python
    >>> from transformers import BloomConfig, BloomModel

    >>> # Initializing a Bloom configuration
    >>> configuration = BloomConfig()

    >>> # Initializing a model (with random weights) from the configuration
    >>> model = BloomModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "bloom"
    keys_to_ignore_at_inference = ["past_key_values"]
    # Exposes the generic attribute names as aliases of the Bloom-specific ones.
    attribute_map = {
        "num_hidden_layers": "n_layer",
        "num_attention_heads": "n_head",
    }

    def __init__(
        self,
        vocab_size=250880,
        hidden_size=64,
        n_layer=2,
        n_head=8,
        layer_norm_epsilon=1e-5,
        initializer_range=0.02,
        use_cache=True,
        bos_token_id=1,
        eos_token_id=2,
        apply_residual_connection_post_layernorm=False,
        hidden_dropout=0.0,
        attention_dropout=0.0,
        pretraining_tp=1,  # TP rank used when training with megatron
        slow_but_exact=False,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        # Backward compatibility with n_embed kwarg
        n_embed = kwargs.pop("n_embed", None)
        self.hidden_size = hidden_size if n_embed is None else n_embed
        self.n_layer = n_layer
        self.n_head = n_head
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_range = initializer_range
        self.use_cache = use_cache
        self.pretraining_tp = pretraining_tp
        self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
        self.hidden_dropout = hidden_dropout
        self.attention_dropout = attention_dropout

        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.slow_but_exact = slow_but_exact

        # The token ids are also forwarded to the base class along with any
        # remaining keyword arguments.
        super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
|
||||
|
||||
|
||||
class BloomOnnxConfig(OnnxConfigWithPast):
    """ONNX export configuration for Bloom, with optional past key/value inputs."""

    torch_onnx_minimum_version = version.parse("1.12")

    def __init__(
        self,
        config: PretrainedConfig,
        task: str = "default",
        patching_specs: Optional[List[PatchingSpec]] = None,
        use_past: bool = False,
    ):
        """Forward all arguments to the base class and default the pad token id.

        Args:
            config: Model configuration the export is based on.
            task: Export task name passed through to the base class.
            patching_specs: Optional patching specs passed through to the base class.
            use_past: Whether past key/value tensors are part of the inputs.
        """
        super().__init__(config, task=task, patching_specs=patching_specs, use_past=use_past)
        # A missing (or falsy) pad token id is replaced by 0.
        if not getattr(self._config, "pad_token_id", None):
            # TODO: how to do that better?
            self._config.pad_token_id = 0

    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        """Return the dynamic-axis names of every model input, in input order."""
        common_inputs = OrderedDict({"input_ids": {0: "batch", 1: "sequence"}})
        if self.use_past:
            # BLOOM stores values on dynamic axis 2. For more details see: https://github.com/huggingface/transformers/pull/18344
            self.fill_with_past_key_values_(common_inputs, direction="inputs", inverted_values_shape=True)
            common_inputs["attention_mask"] = {0: "batch", 1: "past_sequence + sequence"}
        else:
            common_inputs["attention_mask"] = {0: "batch", 1: "sequence"}

        return common_inputs

    @property
    def num_layers(self) -> int:
        """Number of layers, read from the config's `n_layer`."""
        return self._config.n_layer

    @property
    def num_attention_heads(self) -> int:
        """Number of attention heads, read from the config's `n_head`."""
        return self._config.n_head

    @property
    def atol_for_validation(self) -> float:
        """Absolute tolerance used for validation."""
        return 1e-3

    def generate_dummy_inputs(
        self,
        tokenizer: "PreTrainedTokenizer",
        batch_size: int = -1,
        seq_length: int = -1,
        is_pair: bool = False,
        framework: Optional["TensorType"] = None,
    ) -> Mapping[str, Any]:
        """Build dummy inputs ordered as `input_ids`, `past_key_values`, `attention_mask`.

        `past_key_values` is only present when `use_past` is set; in that case
        the attention mask is extended to also cover the past positions.

        Raises:
            ValueError: If `use_past` is set and PyTorch is not installed.
        """
        # Deliberately starts the lookup after `OnnxConfigWithPast` in the MRO,
        # so that class's own implementation is bypassed.
        common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
            tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
        )

        # We need to order the inputs in the way they appear in the forward()
        ordered_inputs = OrderedDict({"input_ids": common_inputs["input_ids"]})

        # Need to add the past_keys
        if self.use_past:
            if not is_torch_available():
                raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
            import torch

            batch, seqlen = common_inputs["input_ids"].shape
            # Not using the same length for past_key_values
            past_key_values_length = seqlen + 2
            head_dim = self._config.hidden_size // self.num_attention_heads
            # Keys and values use transposed layouts of each other.
            past_key_shape = (
                batch * self.num_attention_heads,
                head_dim,
                past_key_values_length,
            )
            past_value_shape = (
                batch * self.num_attention_heads,
                past_key_values_length,
                head_dim,
            )
            ordered_inputs["past_key_values"] = [
                (torch.zeros(past_key_shape), torch.zeros(past_value_shape)) for _ in range(self.num_layers)
            ]

        ordered_inputs["attention_mask"] = common_inputs["attention_mask"]
        if self.use_past:
            # Append ones for the past positions so the mask spans past + current tokens.
            mask_dtype = ordered_inputs["attention_mask"].dtype
            ordered_inputs["attention_mask"] = torch.cat(
                [ordered_inputs["attention_mask"], torch.ones(batch, past_key_values_length, dtype=mask_dtype)], dim=1
            )

        return ordered_inputs

    @property
    def default_onnx_opset(self) -> int:
        """Default ONNX opset version for the export."""
        return 13
|
|
@ -0,0 +1,255 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2022 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert BigScience BLOOM checkpoint."""
|
||||
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
|
||||
import torch
|
||||
|
||||
from transformers import BloomConfig, BloomModel
|
||||
from transformers.file_utils import CONFIG_NAME, WEIGHTS_NAME
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
|
||||
# Parameters whose name ends with one of these suffixes are summed across
# tensor-parallel ranks (and later averaged) instead of being concatenated.
WEIGHTS_TO_AVERAGE_ENDSWITH = [
    "word_embeddings_layernorm.weight",
    "word_embeddings_layernorm.bias",
    "input_layernorm.weight",
    "input_layernorm.bias",
    "post_attention_layernorm.weight",
    "post_attention_layernorm.bias",
    "self_attention.dense.bias",
    "mlp.dense_4h_to_h.bias",
    "ln_f.weight",
    "ln_f.bias",
]

# Parameters whose name contains one of these fragments are concatenated along
# dim 1 when merging tensor-parallel shards; all other concatenated weights use dim 0.
WEIGHTS_WITH_ROW_PARALLELISM_CONTAIN = [
    "mlp.dense_4h_to_h.weight",
    "self_attention.dense.weight",
]
|
||||
|
||||
|
||||
def layer_name_mapping(key, file):
    """Convert Megatron-DeepSpeed TP/PP weights mapping in transformers PP only.

    Args:
        key: Parameter name as stored in the checkpoint shard.
        file: Name of the shard file; it must contain ``layer_<number>`` for
            any key that is not one of the embedding / final-norm names.

    Returns:
        The parameter name used on the transformers side.

    Raises:
        ValueError: If a layer index is needed but ``file`` does not contain one.
    """
    # Handle first and last layers
    layer_rename_map = {
        "word_embeddings.weight": "word_embeddings.weight",
        "word_embeddings.norm.weight": "word_embeddings_layernorm.weight",
        "word_embeddings.norm.bias": "word_embeddings_layernorm.bias",
        "weight": "ln_f.weight",
        "bias": "ln_f.bias",
    }

    if key in layer_rename_map:
        return layer_rename_map[key]

    # Handle transformer blocks
    match = re.match(r".*layer_(\d*).*", file)
    if match is None or not match[1]:
        raise ValueError(f"Cannot extract a layer index from checkpoint file name: {file!r}.")
    # File numbering is shifted by 3 relative to the `h.<n>` block index
    # (presumably the first files hold non-block layers -- TODO confirm).
    layer_number = int(match[1]) - 3
    return f"h.{layer_number}." + key
|
||||
|
||||
|
||||
def get_dtype_size(dtype):
    """Return the size in bytes of one element of ``dtype``.

    The bit width is parsed from the trailing digits of the dtype's string
    form (for example ``torch.float16`` gives 16).

    Args:
        dtype: A torch dtype.

    Returns:
        ``1 / 8`` for ``torch.bool``, otherwise the bit width divided by 8
        as an integer.

    Raises:
        ValueError: If the string form of ``dtype`` does not end in digits.
    """
    if dtype == torch.bool:
        return 1 / 8
    bit_search = re.search(r"[^\d](\d+)$", str(dtype))
    if bit_search is None:
        raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
    bit_size = int(bit_search.groups()[0])
    return bit_size // 8
|
||||
|
||||
|
||||
def _load_merged_tp_tensors(bloom_checkpoint_path, file, pretraining_tp):
    """Load one layer file from every TP rank and merge the ranks into a single state dict."""
    if pretraining_tp < 1:
        raise ValueError(f"`pretraining_tp` must be at least 1, got {pretraining_tp}.")

    tensors = None
    for tp_rank in range(pretraining_tp):
        # Load all TP files: they share the rank-0 file name except for the `model_XX` part.
        f_name = file.replace("model_00", f"model_{tp_rank:02d}")
        # NOTE: `torch.load` unpickles the file; only convert checkpoints from a trusted source.
        temp = torch.load(os.path.join(bloom_checkpoint_path, f_name), map_location="cpu")

        # Rename keys in the transformers names
        for key in list(temp.keys()):
            temp[layer_name_mapping(key, file)] = temp.pop(key)

        if tensors is None:
            tensors = temp
            continue

        for key in tensors.keys():
            if any(key.endswith(end) for end in WEIGHTS_TO_AVERAGE_ENDSWITH):
                # We average (sum and then divide) some weights across TP ranks (see
                # https://github.com/bigscience-workshop/Megatron-DeepSpeed/blob/olruwase/sync_layer_norms/megatron/training.py#L425)
                tensors[key] += temp[key]
            else:
                # Some weights are RowParallelLinear in Megatron-Deepspeed, others are ColumnParallel
                cat_dim = 1 if any(text in key for text in WEIGHTS_WITH_ROW_PARALLELISM_CONTAIN) else 0
                # We concatenate these weights across TP ranks
                tensors[key] = torch.cat([tensors[key], temp[key]], dim=cat_dim)

    # Divide by the number of TP the weights we want to average
    for key in tensors.keys():
        if any(key.endswith(end) for end in WEIGHTS_TO_AVERAGE_ENDSWITH):
            tensors[key] = tensors[key] / pretraining_tp

    return tensors


def convert_bloom_checkpoint_to_pytorch(
    bloom_checkpoint_path, bloom_config_file, pytorch_dump_folder_path, shard_model, pretraining_tp
):
    """Convert a Megatron-DeepSpeed BLOOM checkpoint into transformers-style PyTorch files.

    Args:
        bloom_checkpoint_path: Folder holding the `layer*model_XX*` checkpoint files.
        bloom_config_file: Path to a `BloomConfig` JSON file, or `""` to use the default config.
        pytorch_dump_folder_path: Output folder; created if it does not exist.
        shard_model: If true, write one `pytorch_model_XXXXX-of-YYYYY.bin` file per layer file
            plus a weight index; otherwise load everything into one `BloomModel` and save it
            as a single weights file.
        pretraining_tp: Number of tensor-parallel ranks to merge for each layer file.

    Raises:
        ValueError: If `pretraining_tp` is smaller than 1 while there are layer files to merge.
        AssertionError: (non-sharded mode) If the checkpoint has unexpected keys or some model
            parameters are never provided by any layer file.
    """
    # Construct model
    if bloom_config_file == "":
        config = BloomConfig()
    else:
        config = BloomConfig.from_json_file(bloom_config_file)

    # Only the rank-0 file of each layer is listed; the other ranks are derived from its name.
    file_names = os.listdir(bloom_checkpoint_path)
    file_names = sorted(filter(lambda s: s.startswith("layer") and "model_00" in s, file_names))

    pytorch_config_dump_path = pytorch_dump_folder_path + "/" + CONFIG_NAME

    if shard_model:
        # The shards are written as we go, so the output folder has to exist up front.
        os.makedirs(pytorch_dump_folder_path, exist_ok=True)

        index_dict = {"weight_map": {}, "metadata": {}}
        total_size = 0
        num_shards = len(file_names)

        for j, file in enumerate(file_names):
            print("Processing file: {}".format(file))
            tensors = _load_merged_tp_tensors(bloom_checkpoint_path, file, pretraining_tp)

            shard_name = "pytorch_model_{}-of-{}.bin".format(str(j + 1).zfill(5), str(num_shards).zfill(5))
            torch.save(tensors, os.path.join(pytorch_dump_folder_path, shard_name))

            for key, value in tensors.items():
                total_size += value.numel() * get_dtype_size(value.dtype)
                # The first shard that provides a key owns it in the index.
                if key not in index_dict["weight_map"]:
                    index_dict["weight_map"][key] = shard_name

        index_dict["metadata"]["total_size"] = total_size

        # Write the config that was actually requested (previously a default `BloomConfig()`
        # silently replaced the one loaded from `bloom_config_file`).
        with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
            f.write(config.to_json_string())
        with open(os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME + ".index.json"), "w", encoding="utf-8") as f:
            json_config = json.dumps(index_dict, indent=2, sort_keys=True) + "\n"
            f.write(json_config)
    else:
        model = BloomModel(config)

        # Keys still missing after all files were loaded = intersection of the per-file misses.
        missing_keys = None
        for file in file_names:
            tensors = _load_merged_tp_tensors(bloom_checkpoint_path, file, pretraining_tp)

            other_keys = model.load_state_dict(tensors, strict=False)
            assert not other_keys.unexpected_keys, f"The keys {other_keys.unexpected_keys} are unexpected"
            if missing_keys is None:
                missing_keys = set(other_keys.missing_keys)
            else:
                missing_keys = missing_keys.intersection(set(other_keys.missing_keys))

        assert not missing_keys, f"The keys {missing_keys} are missing"

        # Save pytorch-model
        os.makedirs(pytorch_dump_folder_path, exist_ok=True)
        pytorch_weights_dump_path = pytorch_dump_folder_path + "/" + WEIGHTS_NAME
        print(f"Save PyTorch model to {pytorch_weights_dump_path} with dtype {config.torch_dtype}")
        if config.torch_dtype is not None:
            model = model.to(config.torch_dtype)
        torch.save(model.state_dict(), pytorch_weights_dump_path)
        print(f"Save configuration file to {pytorch_config_dump_path}")
        with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
            f.write(config.to_json_string())
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Command-line entry point: collect the conversion settings and run the converter.
    cli = argparse.ArgumentParser()

    # Required parameters
    cli.add_argument(
        "--bloom_checkpoint_path",
        default=None,
        type=str,
        required=True,
        help="Path to the Megatron-LM checkpoint path.",
    )
    cli.add_argument(
        "--pytorch_dump_folder_path",
        default=None,
        type=str,
        required=True,
        help="Path to the output PyTorch model.",
    )

    # Optional parameters
    cli.add_argument(
        "--bloom_config_file",
        default="",
        type=str,
        help=(
            "An optional config json file corresponding to the pre-trained model. \n"
            "This specifies the model architecture."
        ),
    )
    cli.add_argument(
        "--shard_model",
        action="store_true",
        help="An optional setting to shard the output model \nThis enables sharding the converted checkpoint",
    )
    cli.add_argument(
        "--pretraining_tp",
        default=4,
        type=int,
        help="Pretraining TP rank that has been used when training the model in Megatron-LM \n",
    )

    options = cli.parse_args()
    convert_bloom_checkpoint_to_pytorch(
        bloom_checkpoint_path=options.bloom_checkpoint_path,
        bloom_config_file=options.bloom_config_file,
        pytorch_dump_folder_path=options.pytorch_dump_folder_path,
        shard_model=options.shard_model,
        pretraining_tp=options.pretraining_tp,
    )
|
|
@ -0,0 +1,69 @@
|
|||
# from modelscope.utils.constant import Tasks
|
||||
# from modelscope.pipelines import pipeline
|
||||
|
||||
# prompt = "输入: 中国的首都在哪里\n输出: "
|
||||
|
||||
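# # `task` may be 抽取 (extraction) or 分类 (classification). `text` is the text to analyze. `labels` is the list of types, separated by Chinese (full-width) commas.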
|
||||
# inputs = {'task': '抽取', 'text': '杭州欢迎你。', 'labels': '地名'}
|
||||
# # Keep PROMPT_TEMPLATE unchanged
|
||||
# PROMPT_TEMPLATE = '输入: {text}\n{task}: {labels}\n输出: '
|
||||
# # prompt = PROMPT_TEMPLATE.format(**inputs)
|
||||
# pipeline_ins = pipeline(task=Tasks.text_generation, model='damo/nlp_seqgpt-560m', model_revision = 'v1.0.1', run_kwargs={'gen_token': '[GEN]'})
|
||||
# print(pipeline_ins(prompt))
|
||||
|
||||
import torch
|
||||
import json
|
||||
from modelscope.pipelines import pipeline
|
||||
from modelscope.utils.constant import Tasks
|
||||
|
||||
from modelscope.preprocessors import TextGenerationTransformersPreprocessor
|
||||
|
||||
from modelscope.utils.config import Config
|
||||
from modelscope import snapshot_download
|
||||
from transformers import AutoConfig
|
||||
|
||||
from modeling_bloom import BloomForCausalLM
|
||||
from tokenization_bloom_fast import BloomTokenizerFast
|
||||
|
||||
# Demo script: load the SeqGPT (BLOOM-based) checkpoint from ModelScope and generate an answer
# for a single prompt with beam search.

# Fix the RNG seeds so repeated runs start from the same state.
seed = 4321
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

# Run on the GPU when one is available instead of failing on CPU-only hosts.
device = "cuda" if torch.cuda.is_available() else "cpu"

model_dir = snapshot_download("damo/nlp_seqgpt-560m")

# NOTE(review): `config` / `kwargs` are not passed to `from_pretrained` below, which resolves
# its own configuration from `model_dir`; the call is kept as in the original script -- confirm
# whether it can be dropped.
config, kwargs = AutoConfig.from_pretrained(
    model_dir,
    return_unused_kwargs=True,
    trust_remote_code=True,
    code_revision=None,
    _commit_hash=None,
)

# Build the tokenizer from the JSON files in the current working directory.
tokenizer_config_file = "./tokenizer_config.json"
with open(tokenizer_config_file, encoding="utf-8") as tokenizer_config_handle:
    init_kwargs = json.load(tokenizer_config_handle)

# Drop the entries that are overridden below or must not reach the tokenizer constructor.
init_kwargs.pop("tokenizer_class", None)
init_kwargs.pop("tokenizer_file", None)
saved_init_inputs = init_kwargs.pop("init_inputs", ())
init_inputs = saved_init_inputs
init_kwargs["vocab_file"] = None
init_kwargs["added_tokens_file"] = None
init_kwargs["special_tokens_map_file"] = "./special_tokens_map.json"
init_kwargs["tokenizer_file"] = "./tokenizer.json"
init_kwargs["name_or_path"] = model_dir
tokenizer = BloomTokenizerFast(*init_inputs, **init_kwargs)

# `from_pretrained` is called on the class: building a randomly initialised model first only
# wasted time and memory.
model = BloomForCausalLM.from_pretrained(model_dir).to(device).eval()

prompt = "输入: 美国的首都在哪里\n输出: "

encoded = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=1024)
input_ids = encoded.input_ids.to(device)
outputs = model.generate(input_ids, num_beams=4, do_sample=False, max_new_tokens=256)
decoded_sentences = tokenizer.batch_decode(outputs, skip_special_tokens=True)
print(decoded_sentences[0])
|
|
@ -0,0 +1,7 @@
|
|||
{
|
||||
"_from_model_config": true,
|
||||
"bos_token_id": 1,
|
||||
"eos_token_id": 2,
|
||||
"pad_token_id": 3,
|
||||
"transformers_version": "4.28.1"
|
||||
}
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,734 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2023 HuggingFace Inc. Team and Bigscience Workshop. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Flax BLOOM model."""
|
||||
|
||||
import math
|
||||
from functools import partial
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import flax.linen as nn
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
|
||||
from flax.linen import combine_masks, dot_product_attention_weights, make_causal_mask
|
||||
from flax.linen.activation import tanh
|
||||
from flax.traverse_util import flatten_dict, unflatten_dict
|
||||
from jax import lax
|
||||
|
||||
from ...modeling_flax_outputs import (
|
||||
FlaxBaseModelOutput,
|
||||
FlaxBaseModelOutputWithPastAndCrossAttentions,
|
||||
FlaxCausalLMOutput,
|
||||
)
|
||||
from ...modeling_flax_utils import FlaxPreTrainedModel, append_call_sample_docstring
|
||||
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging
|
||||
from .configuration_bloom import BloomConfig
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
_CHECKPOINT_FOR_DOC = "bigscience/bloom"
|
||||
_CONFIG_FOR_DOC = "BloomConfig"
|
||||
|
||||
|
||||
BLOOM_START_DOCSTRING = r"""
|
||||
|
||||
This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
||||
etc.)
|
||||
|
||||
This model is also a Flax Linen
|
||||
[flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a
|
||||
regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior.
|
||||
|
||||
Finally, this model supports inherent JAX features such as:
|
||||
|
||||
- [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
|
||||
- [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
|
||||
- [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
|
||||
- [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
|
||||
|
||||
Parameters:
|
||||
config ([`BloomConfig`]): Model configuration class with all the parameters of the model.
|
||||
Initializing with a config file does not load the weights associated with the model, only the
|
||||
configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
|
||||
dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
|
||||
The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
|
||||
`jax.numpy.bfloat16` (on TPUs).
|
||||
|
||||
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
|
||||
specified all the computation will be performed with the given `dtype`.
|
||||
|
||||
**Note that this only specifies the dtype of the computation and does not influence the dtype of model
|
||||
parameters.**
|
||||
|
||||
If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
|
||||
[`~FlaxPreTrainedModel.to_bf16`].
|
||||
"""
|
||||
|
||||
BLOOM_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
input_ids (`numpy.ndarray` of shape `(batch_size, input_ids_length)`):
|
||||
`input_ids_length` = `sequence_length`. Indices of input sequence tokens in the vocabulary.
|
||||
|
||||
Indices can be obtained using [`BloomTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
||||
[`PreTrainedTokenizer.__call__`] for details.
|
||||
|
||||
[What are input IDs?](../glossary#input-ids)
|
||||
attention_mask (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
||||
|
||||
- 1 for tokens that are **not masked**,
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
[What are attention masks?](../glossary#attention-mask)
|
||||
past_key_values (`Dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`):
|
||||
Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast
|
||||
auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*.
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
||||
tensors for more detail.
|
||||
output_hidden_states (`bool`, *optional*):
|
||||
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
||||
more detail.
|
||||
return_dict (`bool`, *optional*):
|
||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||
"""
|
||||
|
||||
|
||||
def build_alibi_tensor(attention_mask: jnp.ndarray, num_heads: int, dtype: Optional[jnp.dtype] = jnp.float32):
    """
    Flax implementation of the BLOOM Alibi tensor. BLOOM Alibi tensor is not causal as the original paper mentions, it
    relies on a translation invariance of softmax for quick implementation: with l being a tensor, and a fixed value
    `softmax(l+a) = softmax(l)`. Based on
    https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742
    Link to paper: https://arxiv.org/abs/2108.12409

    Args:
        attention_mask (`jnp.ndarray`):
            Token-wise attention mask, this should be of shape `(batch_size, max_seq_len)`.
        num_heads (`int`):
            Number of attention heads.
        dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`):
            The data type (dtype) of the output tensor.

    Returns: Alibi tensor of shape `(batch_size, num_heads, 1, max_seq_len)`.

    Raises:
        ValueError: If `attention_mask` is not 2-dimensional.
    """
    if len(attention_mask.shape) != 2:
        raise ValueError(
            f"`attention_mask` must be of shape (batch_size, max_seq_len), got shape {attention_mask.shape}."
        )

    # One slope per head: a geometric sequence built for the largest power of two <= num_heads.
    closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
    base = jnp.array(2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), dtype=jnp.float32)
    powers = jnp.arange(1, 1 + closest_power_of_2, dtype=jnp.float32)
    slopes = jax.lax.pow(base, powers)

    if closest_power_of_2 != num_heads:
        # The remaining heads take every other slope of the sequence for the next power of two.
        extra_base = jnp.array(2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), dtype=jnp.float32)
        num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
        extra_powers = jnp.arange(1, 1 + 2 * num_remaining_heads, 2, dtype=jnp.float32)
        # `jax.numpy` has no `cat`; `concatenate` is the equivalent of `torch.cat`.
        slopes = jnp.concatenate([slopes, jax.lax.pow(extra_base, extra_powers)], axis=0)

    # Note: the Alibi tensor will be added to the attention bias that is applied to the query-key product of
    # attention, so it is given a singleton query_length axis that broadcasts over the query positions:
    # (batch_size, num_heads, query_length=1, key_length=max_seq_len).
    # This is more or less identical to T5's relative position bias:
    # https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527

    # Position of every token counted over the mask's non-zero entries; masked positions are zeroed.
    arange_tensor = ((attention_mask.cumsum(axis=-1) - 1) * attention_mask)[:, None, :]
    alibi = slopes[..., None] * arange_tensor
    alibi = jnp.expand_dims(alibi, axis=2)
    return jnp.asarray(alibi, dtype)
|
||||
|
||||
|
||||
class FlaxBloomAttention(nn.Module):
    """
    BLOOM self-attention: a fused query/key/value projection, an additive `alibi` bias on the attention scores, an
    optional key/value cache for step-by-step decoding, and an output projection followed by dropout and the
    residual addition.
    """

    config: BloomConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.hidden_size = self.config.hidden_size
        self.num_heads = self.config.n_head
        self.head_dim = self.hidden_size // self.num_heads
        # Identity check: true whenever `dtype` is not the `jnp.float32` object itself.
        self.attention_softmax_in_fp32 = self.dtype is not jnp.float32

        if self.head_dim * self.num_heads != self.hidden_size:
            raise ValueError(
                f"`hidden_size` must be divisible by `num_heads` (got `hidden_size`: {self.hidden_size} and "
                f"`num_heads`: {self.num_heads})."
            )

        # Shared constructor for both projections: same dtype and kernel initializer.
        dense = partial(
            nn.Dense,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
        )

        self.query_key_value = dense(self.hidden_size * 3)
        self.dense = dense(self.hidden_size)
        self.resid_dropout = nn.Dropout(rate=self.config.hidden_dropout)

    def _split_heads(self, hidden_states):
        # Reshape the last axis into (num_heads, 3 * head_dim); the caller then splits that last axis
        # into query, key and value.
        return hidden_states.reshape(hidden_states.shape[:-1] + (self.num_heads, self.head_dim * 3))

    def _merge_heads(self, hidden_states):
        # Collapse everything after the first two axes into a single axis of size `hidden_size`.
        return hidden_states.reshape(hidden_states.shape[:2] + (self.hidden_size,))

    @nn.compact
    # Copied from transformers.models.gptj.modeling_flax_gptj.FlaxGPTJAttention._concatenate_to_cache
    def _concatenate_to_cache(self, key, value, query, attention_mask):
        """
        This function takes projected key, value states from a single input token and concatenates the states to cached
        states from previous steps. This function is slightly adapted from the official Flax repository:
        https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252

        Returns the (possibly cache-extended) key and value together with the attention mask, which is combined
        with a mask over the cache positions filled so far when the cache already existed.
        """
        # detect if we're initializing by absence of existing cache data.
        is_initialized = self.has_variable("cache", "cached_key")
        # Declaring the variables creates them (zero-filled, shaped like the incoming key/value) on the first call.
        cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
        cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
        cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))

        if is_initialized:
            *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
            # update key, value caches with our new 1d spatial slices
            cur_index = cache_index.value
            indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
            key = lax.dynamic_update_slice(cached_key.value, key, indices)
            value = lax.dynamic_update_slice(cached_value.value, value, indices)
            cached_key.value = key
            cached_value.value = value
            # Advance the write position by the number of query positions processed in this call.
            num_updated_cache_vectors = query.shape[1]
            cache_index.value = cache_index.value + num_updated_cache_vectors
            # causal mask for cached decoder self-attention: our single query position should only attend to those key
            # positions that have already been generated and cached, not the remaining zero elements.
            pad_mask = jnp.broadcast_to(
                jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
                tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
            )
            attention_mask = combine_masks(pad_mask, attention_mask)
        return key, value, attention_mask

    def __call__(
        self,
        hidden_states,
        residual,
        alibi,
        attention_mask=None,
        deterministic: bool = True,
        init_cache: bool = False,
        output_attentions: bool = False,
    ):
        """
        Apply self-attention to `hidden_states` and add `residual` to the projected result.

        Args:
            hidden_states: Input activations; the first two axes are read as `(batch_size, seq_length)`.
            residual: Tensor added to the attention output after the output projection and dropout.
            alibi: Bias added to the attention bias before the softmax.
                NOTE(review): presumably the output of `build_alibi_tensor` -- confirm against the caller.
            attention_mask: Token-wise mask; it is combined with a causal mask built from it.
                NOTE(review): defaults to `None`, but it is passed to `make_causal_mask` unconditionally --
                confirm callers always provide it.
            deterministic: When true, dropout is disabled.
            init_cache: When true, the key/value cache variables are created.
            output_attentions: When true, the attention weights are returned as well.

        Returns:
            `(attn_output,)`, or `(attn_output, attn_weights)` when `output_attentions` is true.
        """
        batch_size, seq_length = hidden_states.shape[:2]

        # proj q, k, v
        fused_qkv = self.query_key_value(hidden_states)
        fused_qkv = self._split_heads(fused_qkv)
        query, key, value = jnp.split(fused_qkv, 3, axis=-1)

        causal_attention_mask = make_causal_mask(attention_mask, dtype="bool")

        # for fast decoding causal attention mask should be shifted
        causal_attention_mask_shift = (
            self.variables["cache"]["cache_index"] if self.has_variable("cache", "cached_key") else 0
        )

        # fast decoding for generate requires special attention_mask
        if self.has_variable("cache", "cached_key"):
            max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
            # Keep only the rows of the causal mask for the query positions handled in this call.
            causal_attention_mask = jax.lax.dynamic_slice(
                causal_attention_mask,
                (0, 0, causal_attention_mask_shift, 0),
                (1, 1, seq_length, max_decoder_length),
            )

        # broadcast causal attention mask & attention mask to fit for merge
        causal_attention_mask = jnp.broadcast_to(
            causal_attention_mask, (batch_size,) + causal_attention_mask.shape[1:]
        )
        attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_attention_mask.shape)
        attention_mask = combine_masks(attention_mask, causal_attention_mask)

        # A dropout RNG is only drawn when attention dropout is actually active.
        dropout_rng = None
        if not deterministic and self.config.attention_dropout > 0.0:
            dropout_rng = self.make_rng("dropout")

        # During fast autoregressive decoding, we feed one position at a time,
        # and cache the keys and values step by step.
        if self.has_variable("cache", "cached_key") or init_cache:
            key, value, attention_mask = self._concatenate_to_cache(key, value, query, attention_mask)

        # transform boolean mask into float mask
        mask_value = jnp.finfo(self.dtype).min
        attention_bias = lax.select(
            attention_mask > 0,
            jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
            jnp.full(attention_mask.shape, mask_value).astype(self.dtype),
        )

        attention_bias = attention_bias + alibi

        # Cast in fp32 if the original dtype is different from fp32
        attention_dtype = jnp.float32 if self.attention_softmax_in_fp32 else self.dtype

        attn_weights = dot_product_attention_weights(
            query,
            key,
            bias=attention_bias,
            dropout_rng=dropout_rng,
            dropout_rate=self.config.attention_dropout,
            deterministic=deterministic,
            dtype=attention_dtype,
        )

        # Cast back in the original dtype if the native dtype is not fp32
        if self.attention_softmax_in_fp32:
            attn_weights = attn_weights.astype(self.dtype)

        # Weighted sum of the values per head, then merge the heads and project back.
        attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value)
        attn_output = self._merge_heads(attn_output)
        attn_output = self.dense(attn_output)
        attn_output = self.resid_dropout(attn_output, deterministic=deterministic)

        attn_output = attn_output + residual

        outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
        return outputs
|
||||
|
||||
|
||||
class BloomGELU(nn.Module):
    """Tanh-based GELU activation used by the BLOOM feed-forward layer."""

    def setup(self):
        self.dtype = jnp.float32

    def __call__(self, x):
        # 0.79788456 and 0.044715 are the constants of the tanh formulation used here.
        tanh_arg = 0.79788456 * x * (1 + 0.044715 * x * x)
        return x * 0.5 * (1.0 + tanh(tanh_arg))
|
||||
|
||||
|
||||
class FlaxBloomMLP(nn.Module):
    """BLOOM feed-forward layer: expand to `4 * hidden_size`, apply GELU, project back and add the residual."""

    config: BloomConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        width = self.config.hidden_size

        # Both projections share the dtype and the kernel initializer.
        make_dense = partial(
            nn.Dense,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
        )

        self.dense_h_to_4h = make_dense(4 * width)
        self.dense_4h_to_h = make_dense(width)
        self.hidden_dropout = nn.Dropout(self.config.hidden_dropout)
        self.act = BloomGELU()

    def __call__(self, hidden_states, residual, deterministic: bool = True):
        """Return `dropout(dense_4h_to_h(act(dense_h_to_4h(hidden_states))) + residual)`."""
        expanded = self.act(self.dense_h_to_4h(hidden_states))
        projected = self.dense_4h_to_h(expanded)

        # NOTE(review): dropout is applied after the residual addition here, whereas
        # FlaxBloomAttention applies dropout before adding its residual -- confirm which is intended.
        with_residual = projected + residual
        return self.hidden_dropout(with_residual, deterministic=deterministic)
|
||||
|
||||
|
||||
class FlaxBloomBlock(nn.Module):
    """One BLOOM transformer block: layer norm -> self-attention -> layer norm -> MLP, with residual connections."""

    config: BloomConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        cfg = self.config

        self.input_layernorm = nn.LayerNorm(epsilon=cfg.layer_norm_epsilon, dtype=self.dtype)
        self.self_attention = FlaxBloomAttention(cfg, dtype=self.dtype)
        self.post_attention_layernorm = nn.LayerNorm(epsilon=cfg.layer_norm_epsilon, dtype=self.dtype)
        self.mlp = FlaxBloomMLP(cfg, dtype=self.dtype)

        self.apply_residual_connection_post_layernorm = cfg.apply_residual_connection_post_layernorm
        self.hidden_dropout = cfg.hidden_dropout

    def __call__(
        self,
        hidden_states,
        alibi,
        attention_mask=None,
        deterministic: bool = True,
        init_cache: bool = False,
        output_attentions: bool = False,
    ):
        """
        Run the block on `hidden_states`.

        Returns a tuple whose first element is the block output, followed by whatever the attention layer
        returned after its own output (the attention weights when `output_attentions` is true).
        """
        normed_input = self.input_layernorm(hidden_states)

        # The residual is taken after the layer norm only if the config calls for it.
        attention_residual = (
            normed_input if self.apply_residual_connection_post_layernorm else hidden_states
        )

        # self-attention
        attn_outputs = self.self_attention(
            normed_input,
            residual=attention_residual,
            alibi=alibi,
            attention_mask=attention_mask,
            deterministic=deterministic,
            init_cache=init_cache,
            output_attentions=output_attentions,
        )
        attention_output = attn_outputs[0]
        extra_outputs = attn_outputs[1:]

        normed_attention = self.post_attention_layernorm(attention_output)

        # Same choice for the MLP residual.
        mlp_residual = (
            normed_attention if self.apply_residual_connection_post_layernorm else attention_output
        )

        block_output = self.mlp(normed_attention, mlp_residual, deterministic=deterministic)

        return (block_output,) + extra_outputs
|
||||
|
||||
|
||||
class FlaxBloomPreTrainedModel(FlaxPreTrainedModel):
|
||||
"""
|
||||
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
||||
models.
|
||||
"""
|
||||
|
||||
config_class = BloomConfig
|
||||
base_model_prefix = "transformer"
|
||||
module_class: nn.Module = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: BloomConfig,
|
||||
input_shape: Tuple = (1, 1),
|
||||
seed: int = 0,
|
||||
dtype: jnp.dtype = jnp.float32,
|
||||
_do_init: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
|
||||
|
||||
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
|
||||
# init input tensors
|
||||
input_ids = jnp.zeros(input_shape, dtype="i4")
|
||||
attention_mask = jnp.ones_like(input_ids)
|
||||
params_rng, dropout_rng = jax.random.split(rng)
|
||||
rngs = {"params": params_rng, "dropout": dropout_rng}
|
||||
|
||||
random_params = self.module.init(rngs, input_ids, attention_mask, return_dict=False)["params"]
|
||||
|
||||
if params is not None:
|
||||
random_params = flatten_dict(unfreeze(random_params))
|
||||
params = flatten_dict(unfreeze(params))
|
||||
for missing_key in self._missing_keys:
|
||||
params[missing_key] = random_params[missing_key]
|
||||
self._missing_keys = set()
|
||||
return freeze(unflatten_dict(params))
|
||||
else:
|
||||
return random_params
|
||||
|
||||
def init_cache(self, batch_size, max_length):
|
||||
r"""
|
||||
Args:
|
||||
batch_size (`int`):
|
||||
batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
|
||||
max_length (`int`):
|
||||
maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
|
||||
cache.
|
||||
"""
|
||||
# init input variables to retrieve cache
|
||||
input_ids = jnp.ones((batch_size, max_length), dtype="i4")
|
||||
attention_mask = jnp.ones_like(input_ids)
|
||||
|
||||
init_variables = self.module.init(
|
||||
jax.random.PRNGKey(0), input_ids, attention_mask, return_dict=False, init_cache=True
|
||||
)
|
||||
return unfreeze(init_variables["cache"])
|
||||
|
||||
@add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING)
def __call__(
    self,
    input_ids,
    attention_mask=None,
    past_key_values: dict = None,
    params: dict = None,
    dropout_rng: jax.random.PRNGKey = None,
    train: bool = False,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
):
    # NOTE: documentation is attached by the decorator, so this function deliberately has no docstring.
    # Unset output flags fall back to the values stored on the config.
    if output_attentions is None:
        output_attentions = self.config.output_attentions
    if output_hidden_states is None:
        output_hidden_states = self.config.output_hidden_states
    if return_dict is None:
        return_dict = self.config.use_return_dict

    batch_size, sequence_length = input_ids.shape
    if attention_mask is None:
        # no mask given: attend to every position
        attention_mask = jnp.ones((batch_size, sequence_length))

    # Handle any PRNG if needed
    rngs = {} if dropout_rng is None else {"dropout": dropout_rng}

    variables = {"params": params or self.params}

    # When past_key_values are passed the cache is already initialized; it is handed to the module as the
    # "cache" collection and marked mutable so that the attention layers can update it.
    mutable = False
    if past_key_values:
        variables["cache"] = past_key_values
        mutable = ["cache"]

    outputs = self.module.apply(
        variables,
        jnp.array(input_ids, dtype="i4"),
        jnp.array(attention_mask, dtype="i4"),
        not train,  # deterministic
        False,  # init_cache
        output_attentions,
        output_hidden_states,
        return_dict,
        rngs=rngs,
        mutable=mutable,
    )

    if past_key_values is None:
        return outputs

    # add updated cache to model output
    outputs, mutated_variables = outputs
    updated_cache = unfreeze(mutated_variables["cache"])
    if return_dict:
        outputs["past_key_values"] = updated_cache
        return outputs
    return outputs[:1] + (updated_cache,) + outputs[1:]
|
||||
|
||||
|
||||
class FlaxBloomBlockCollection(nn.Module):
    """Stack of `config.num_hidden_layers` `FlaxBloomBlock` layers applied one after another."""

    config: BloomConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # each block is named after its index in the stack
        self.layers = [
            FlaxBloomBlock(self.config, name=str(layer_number), dtype=self.dtype)
            for layer_number in range(self.config.num_hidden_layers)
        ]

    def __call__(
        self,
        hidden_states,
        alibi,
        attention_mask=None,
        deterministic: bool = True,
        init_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
    ):
        """
        Run `hidden_states` through every block.

        Returns:
            A 3-tuple `(hidden_states, all_hidden_states, all_attentions)`. The last two entries are tuples
            with one item per block when the matching flag is set, and `None` otherwise.
        """
        collected_attentions = [] if output_attentions else None
        collected_states = [] if output_hidden_states else None

        for block in self.layers:
            if output_hidden_states:
                # record the input of the block, i.e. the state before it is applied
                collected_states.append(hidden_states)

            block_outputs = block(
                hidden_states,
                alibi=alibi,
                attention_mask=attention_mask,
                deterministic=deterministic,
                init_cache=init_cache,
                output_attentions=output_attentions,
            )
            hidden_states = block_outputs[0]

            if output_attentions:
                collected_attentions.append(block_outputs[1])

        all_hidden_states = None if collected_states is None else tuple(collected_states)
        all_attentions = None if collected_attentions is None else tuple(collected_attentions)

        # this contains possible `None` values - `FlaxBloomModule` will filter them out
        return (hidden_states, all_hidden_states, all_attentions)
|
||||
|
||||
|
||||
class FlaxBloomModule(nn.Module):
    """
    Bare Bloom transformer: word embeddings, a post-embedding layer norm, the block stack and a final
    layer norm. There is no positional embedding layer; an ALiBi tensor built from the attention mask
    is passed to the blocks instead.
    """

    config: BloomConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.embed_dim = self.config.hidden_size

        # word embeddings (no positional embedding layer)
        self.word_embeddings = nn.Embed(
            self.config.vocab_size,
            self.embed_dim,
            embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
            dtype=self.dtype,
        )

        # post-embedding layernorm
        self.word_embeddings_layernorm = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype)

        # transformer layers
        self.h = FlaxBloomBlockCollection(self.config, dtype=self.dtype)

        # final layernorm
        self.ln_f = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype)

    def __call__(
        self,
        input_ids=None,
        attention_mask=None,
        deterministic=True,
        init_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        """
        Embed `input_ids`, run the block stack and apply the final layer norm.

        Returns:
            With `return_dict=True`, a `FlaxBaseModelOutputWithPastAndCrossAttentions` carrying
            `last_hidden_state`, `hidden_states` and `attentions`. With `return_dict=False`, a tuple of
            the same three values in that order with `None` entries removed.
        """
        inputs_embeds = self.word_embeddings(input_ids)
        # do post-embedding layernorm
        hidden_states = self.word_embeddings_layernorm(inputs_embeds)

        # build alibi depending on `attention_mask`
        alibi = build_alibi_tensor(attention_mask, self.config.n_head, dtype=hidden_states.dtype)

        outputs = self.h(
            hidden_states,
            alibi=alibi,
            attention_mask=attention_mask,
            deterministic=deterministic,
            init_cache=init_cache,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
        )

        hidden_states = outputs[0]
        hidden_states = self.ln_f(hidden_states)

        if output_hidden_states:
            # the block stack records the input of each block; append the final normalised state
            all_hidden_states = outputs[1] + (hidden_states,)
            outputs = (hidden_states, all_hidden_states) + outputs[2:]
        else:
            outputs = (hidden_states,) + outputs[1:]

        if not return_dict:
            # Keep every populated entry. Selecting only the first and last element dropped
            # `all_hidden_states` from the tuple even when `output_hidden_states=True`.
            return tuple(v for v in outputs if v is not None)

        return FlaxBaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            hidden_states=outputs[1],
            attentions=outputs[-1],
        )
|
||||
|
||||
|
||||
@add_start_docstrings(
    "The bare Bloom Model transformer outputting raw hidden-states without any specific head on top.",
    BLOOM_START_DOCSTRING,
)
# Copied from transformers.models.gpt_neo.modeling_flax_gpt_neo.FlaxGPTNeoModel with GPTNeo->Bloom
class FlaxBloomModel(FlaxBloomPreTrainedModel):
    # The pretrained wrapper instantiates `module_class` as its underlying Flax module.
    # No docstring here on purpose: the class documentation is supplied by the decorator.
    module_class = FlaxBloomModule


# Attach the generated usage example to `FlaxBloomModel.__call__`.
append_call_sample_docstring(FlaxBloomModel, _CHECKPOINT_FOR_DOC, FlaxBaseModelOutput, _CONFIG_FOR_DOC)
|
||||
|
||||
|
||||
class FlaxBloomForCausalLMModule(nn.Module):
    """Bloom transformer followed by a bias-free dense projection onto the vocabulary."""

    config: BloomConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.transformer = FlaxBloomModule(self.config, dtype=self.dtype)
        self.lm_head = nn.Dense(
            self.config.vocab_size,
            use_bias=False,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
        )

    def __call__(
        self,
        input_ids,
        attention_mask,
        deterministic: bool = True,
        init_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        """
        Compute language-modelling logits for `input_ids`.

        Returns:
            A `FlaxCausalLMOutput` when `return_dict` is true, otherwise a tuple whose first element is
            the logits followed by the remaining transformer outputs.
        """
        transformer_outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            deterministic=deterministic,
            init_cache=init_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]

        if self.config.tie_word_embeddings:
            # tied weights: reuse the transposed input embedding matrix as the projection kernel
            embedding = self.transformer.variables["params"]["word_embeddings"]["embedding"]
            lm_logits = self.lm_head.apply({"params": {"kernel": embedding.T}}, hidden_states)
        else:
            lm_logits = self.lm_head(hidden_states)

        if return_dict:
            return FlaxCausalLMOutput(
                logits=lm_logits,
                hidden_states=transformer_outputs.hidden_states,
                attentions=transformer_outputs.attentions,
            )
        return (lm_logits,) + transformer_outputs[1:]
|
||||
|
||||
|
||||
@add_start_docstrings(
    """
    The Bloom Model transformer with a language modeling head on top (linear layer with weights tied to the input
    embeddings).
    """,
    BLOOM_START_DOCSTRING,
)
class FlaxBloomForCausalLM(FlaxBloomPreTrainedModel):
    # No class docstring on purpose: the documentation is supplied by the decorator.
    module_class = FlaxBloomForCausalLMModule

    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
        """
        Build the keyword arguments used for cached generation.

        Returns:
            A dict with a freshly initialised `past_key_values` cache and an `attention_mask` of shape
            `(batch_size, max_length)`.
        """
        # input_ids must be 2-D; only the batch dimension is needed here
        batch_size, _ = input_ids.shape

        # initializing the cache
        cache = self.init_cache(batch_size, max_length)

        # Note that usually one would have to put 0's in the attention_mask for
        # x > input_ids.shape[-1] and x < cache_length. But since Bloom uses a causal mask,
        # those positions are masked anyway. Thus, we can create a single static attention_mask here,
        # which is more efficient for compilation
        full_mask = jnp.ones((batch_size, max_length), dtype="i4")
        if attention_mask is not None:
            # copy the caller's mask into the top-left corner of the static mask
            full_mask = lax.dynamic_update_slice(full_mask, attention_mask, (0, 0))

        return {"past_key_values": cache, "attention_mask": full_mask}

    def update_inputs_for_generation(self, model_outputs, model_kwargs):
        """Carry the cache returned by the last forward pass over to the next step's kwargs."""
        model_kwargs["past_key_values"] = model_outputs.past_key_values
        return model_kwargs


# Attach the generated usage example to `FlaxBloomForCausalLM.__call__`.
append_call_sample_docstring(FlaxBloomForCausalLM, _CHECKPOINT_FOR_DOC, FlaxCausalLMOutput, _CONFIG_FOR_DOC)
|
|
@ -0,0 +1,9 @@
|
|||
{
|
||||
"additional_special_tokens": [
|
||||
"[GEN]"
|
||||
],
|
||||
"bos_token": "<s>",
|
||||
"eos_token": "</s>",
|
||||
"pad_token": "<pad>",
|
||||
"unk_token": "<unk>"
|
||||
}
|
|
@ -0,0 +1,177 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2022 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tokenization classes for Bloom."""
|
||||
|
||||
|
||||
import pickle
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from transformers.tokenization_utils_base import BatchEncoding
|
||||
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
# Maps the tokenizer init argument name to the file name it is loaded from / saved to.
VOCAB_FILES_NAMES = {"tokenizer_file": "tokenizer.json"}

# Download location of the tokenizer file for each known checkpoint.
# The URLs use the Hub's `/resolve/<revision>/` endpoint, which serves the file itself;
# `/blob/<revision>/` is the HTML viewer page and does not return the JSON payload.
PRETRAINED_VOCAB_FILES_MAP = {
    "tokenizer_file": {
        "bigscience/tokenizer": "https://huggingface.co/bigscience/tokenizer/resolve/main/tokenizer.json",
        "bigscience/bloom-560m": "https://huggingface.co/bigscience/bloom-560m/resolve/main/tokenizer.json",
        "bigscience/bloom-1b1": "https://huggingface.co/bigscience/bloom-1b1/resolve/main/tokenizer.json",
        "bigscience/bloom-1b7": "https://huggingface.co/bigscience/bloom-1b7/resolve/main/tokenizer.json",
        "bigscience/bloom-3b": "https://huggingface.co/bigscience/bloom-3b/resolve/main/tokenizer.json",
        "bigscience/bloom-7b1": "https://huggingface.co/bigscience/bloom-7b1/resolve/main/tokenizer.json",
        "bigscience/bloom": "https://huggingface.co/bigscience/bloom/resolve/main/tokenizer.json",
    },
}
|
||||
|
||||
|
||||
class BloomTokenizerFast(PreTrainedTokenizerFast):
    """
    Construct a "fast" Bloom tokenizer (backed by HuggingFace's *tokenizers* library). Based on byte-level
    Byte-Pair-Encoding.

    This tokenizer has been trained to treat spaces like parts of the tokens (a bit like sentencepiece) so a word will
    be encoded differently whether it is at the beginning of the sentence (without space) or not:

    ```python
    >>> from transformers import BloomTokenizerFast

    >>> tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom")
    >>> tokenizer("Hello world")["input_ids"]
    [59414, 8876]

    >>> tokenizer(" Hello world")["input_ids"]
    [86153, 8876]
    ```

    You can get around that behavior by passing `add_prefix_space=True` when instantiating this tokenizer, but since
    the model was not pretrained this way, it might yield a decrease in performance.

    <Tip>

    When used with `is_split_into_words=True`, this tokenizer needs to be instantiated with `add_prefix_space=True`.

    </Tip>

    This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
    refer to this superclass for more information regarding those methods.

    Args:
        vocab_file (`str`, *optional*):
            Path to the vocabulary file. Forwarded positionally to [`PreTrainedTokenizerFast`].
        merges_file (`str`, *optional*):
            Path to the merges file. Forwarded positionally to [`PreTrainedTokenizerFast`].
        tokenizer_file (`str`, *optional*):
            Path to the serialized tokenizer file (`tokenizer.json`, see `VOCAB_FILES_NAMES`).
        unk_token (`str`, *optional*, defaults to `"<unk>"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
        bos_token (`str`, *optional*, defaults to `"<s>"`):
            The beginning of sequence token.
        eos_token (`str`, *optional*, defaults to `"</s>"`):
            The end of sequence token.
        pad_token (`str`, *optional*, defaults to `"<pad>"`):
            The token used for padding.
        add_prefix_space (`bool`, *optional*, defaults to `False`):
            Whether or not to add an initial space to the input. This allows to treat the leading word just as any
            other word. (Bloom tokenizer detect beginning of words by the preceding space).
        clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
            Forwarded unchanged to [`PreTrainedTokenizerFast`].
        **kwargs:
            Any further keyword arguments are forwarded unchanged to [`PreTrainedTokenizerFast`].
    """

    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    model_input_names = ["input_ids", "attention_mask"]
    slow_tokenizer_class = None
    # No `max_model_input_sizes` as BLOOM uses ALiBi positional embeddings

    def __init__(
        self,
        vocab_file=None,
        merges_file=None,
        tokenizer_file=None,
        unk_token="<unk>",
        bos_token="<s>",
        eos_token="</s>",
        pad_token="<pad>",
        add_prefix_space=False,
        clean_up_tokenization_spaces=False,
        **kwargs,
    ):
        super().__init__(
            vocab_file,
            merges_file,
            tokenizer_file=tokenizer_file,
            unk_token=unk_token,
            bos_token=bos_token,
            eos_token=eos_token,
            pad_token=pad_token,
            add_prefix_space=add_prefix_space,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            **kwargs,
        )
        # TODO @ArthurZucker this can only work one way for now, to update later-on. Tests should also properly
        # check this as they were green before.
        # The backend pre-tokenizer and decoder are round-tripped through pickle so that the
        # `add_prefix_space` flag can be patched in their serialized state. The pickled bytes come from
        # this process's own backend objects, not from external input.
        pre_tok_state = pickle.dumps(self.backend_tokenizer.pre_tokenizer)
        decoder_state = pickle.dumps(self.backend_tokenizer.decoder)

        if add_prefix_space:
            pre_tok_state = pre_tok_state.replace(b'"add_prefix_space":false', b'"add_prefix_space": true')
            decoder_state = decoder_state.replace(b'"add_prefix_space":false', b'"add_prefix_space": true')
        self.backend_tokenizer.pre_tokenizer = pickle.loads(pre_tok_state)
        self.backend_tokenizer.decoder = pickle.loads(decoder_state)

        self.add_prefix_space = add_prefix_space

    def _require_prefix_space_for_pretokenized(self, is_split_into_words) -> None:
        """Raise `ValueError` when pretokenized input is used without `add_prefix_space=True`."""
        if is_split_into_words and not self.add_prefix_space:
            raise ValueError(
                f"You need to instantiate {self.__class__.__name__} with add_prefix_space=True to use it with"
                " pretokenized inputs."
            )

    def _batch_encode_plus(self, *args, **kwargs) -> BatchEncoding:
        """Validate the `is_split_into_words` keyword, then defer to the parent implementation."""
        self._require_prefix_space_for_pretokenized(kwargs.get("is_split_into_words", False))
        return super()._batch_encode_plus(*args, **kwargs)

    def _encode_plus(self, *args, **kwargs) -> BatchEncoding:
        """Validate the `is_split_into_words` keyword, then defer to the parent implementation."""
        self._require_prefix_space_for_pretokenized(kwargs.get("is_split_into_words", False))
        return super()._encode_plus(*args, **kwargs)

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        """Save the backend tokenizer model into `save_directory` and return the written file paths."""
        files = self._tokenizer.model.save(save_directory, name=filename_prefix)
        return tuple(files)

    @property
    # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.default_chat_template
    def default_chat_template(self):
        """
        A simple chat template that ignores role information and just concatenates messages with EOS tokens.
        """
        logger.warning_once(
            "\nNo chat template is defined for this tokenizer - using the default template "
            f"for the {self.__class__.__name__} class. If the default is not appropriate for "
            "your model, please set `tokenizer.chat_template` to an appropriate template. "
            "See https://huggingface.co/docs/transformers/main/chat_templating for more information.\n"
        )
        return "{% for message in messages %}" "{{ message.content }}{{ eos_token }}" "{% endfor %}"
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,11 @@
|
|||
{
|
||||
"add_prefix_space": false,
|
||||
"bos_token": "<s>",
|
||||
"clean_up_tokenization_spaces": false,
|
||||
"eos_token": "</s>",
|
||||
"model_max_length": 1000000000000000019884624838656,
|
||||
"pad_token": "<pad>",
|
||||
"padding_side": "left",
|
||||
"tokenizer_class": "BloomTokenizer",
|
||||
"unk_token": "<unk>"
|
||||
}
|
Binary file not shown.
Loading…
Reference in New Issue