Compare commits
10 Commits
5e6b747baf
...
7d16743184
Author | SHA1 | Date |
---|---|---|
Colin | 7d16743184 | |
周以晴 | b655153ec7 | |
yiqing-zhou | 9f8f9ecc89 | |
周以晴 | e8d543558c | |
yiqing-zhou | fcb93e52c4 | |
yiqing-zhou | b76d333f39 | |
周以晴 | 10a88a5012 | |
Yiqing-Zhou | 30df20402d | |
Yiqing-Zhou | 6827898339 | |
Yiqing-Zhou | 216bc4643c |
|
@ -4,25 +4,15 @@
|
||||||
// 欲了解更多信息,请访问: https://go.microsoft.com/fwlink/?linkid=830387
|
// 欲了解更多信息,请访问: https://go.microsoft.com/fwlink/?linkid=830387
|
||||||
"version": "0.2.0",
|
"version": "0.2.0",
|
||||||
"configurations": [
|
"configurations": [
|
||||||
{
|
|
||||||
"name": "Python: train",
|
|
||||||
"type": "python",
|
|
||||||
"request": "launch",
|
|
||||||
"program": "train.py",
|
|
||||||
"args": [
|
|
||||||
"--dataset_name", "wikitext:wikitext-2-v1",
|
|
||||||
],
|
|
||||||
"console": "integratedTerminal",
|
|
||||||
"justMyCode": true
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"name": "Python: generate",
|
"name": "Python: generate",
|
||||||
"type": "python",
|
"type": "debugpy",
|
||||||
"request": "launch",
|
"request": "launch",
|
||||||
"program": "generate.py",
|
"program": "${file}",
|
||||||
"args": [],
|
"args": [],
|
||||||
|
"cwd": "${fileDirname}",
|
||||||
"console": "integratedTerminal",
|
"console": "integratedTerminal",
|
||||||
"justMyCode": true
|
"justMyCode": false
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
}
|
}
|
46
README.md
46
README.md
|
@ -1,9 +1,51 @@
|
||||||
# GPT-Pretrain
|
# GPT-Pretrain
|
||||||
|
|
||||||
## Usage
|
# Usage
|
||||||
|
|
||||||
|
## Make it simple
|
||||||
|
|
||||||
```
|
```
|
||||||
python lit_train.py --model_name gpt2
|
python lit_train.py --model_name gpt2 --use_tril_attention_mask
|
||||||
python lit_export.py --version 0
|
python lit_export.py --version 0
|
||||||
python generate.py --model_name_or_path exports/version_0 --tokenizer_name_or_path gpt2
|
python generate.py --model_name_or_path exports/version_0 --tokenizer_name_or_path gpt2
|
||||||
```
|
```
|
||||||
|
> :memo: **Note:** Training with a "--use_tril_attention_mask" is recommended. However, huggingface model implementions might not support 2D attention mask. You may write a custom model to support 2D attention mask, just like what I did in [custom_models/gpt2](https://github.com/Yiqing-Zhou/gpt-pretrain/tree/main/custom_models/gpt2).
|
||||||
|
|
||||||
|
## Train on multiple GPUs
|
||||||
|
|
||||||
|
```
|
||||||
|
python lit_train.py --model_name gpt2 --use_tril_attention_mask --strategy fsdp # default and recommended
|
||||||
|
```
|
||||||
|
|
||||||
|
```
|
||||||
|
python lit_train.py --model_name gpt2 --use_tril_attention_mask --strategy deepspeed
|
||||||
|
```
|
||||||
|
|
||||||
|
```
|
||||||
|
python lit_train.py --model_name gpt2 --use_tril_attention_mask --strategy ddp
|
||||||
|
```
|
||||||
|
|
||||||
|
## Reduce CUDA memory cost
|
||||||
|
|
||||||
|
- half precision
|
||||||
|
```
|
||||||
|
python lit_train.py --model_name gpt2 --use_tril_attention_mask --bf16
|
||||||
|
```
|
||||||
|
```
|
||||||
|
python lit_train.py --model_name gpt2 --use_tril_attention_mask --fp16
|
||||||
|
```
|
||||||
|
- smaller batch size & accumulate grad batches
|
||||||
|
```
|
||||||
|
python lit_train.py --model_name gpt2 --use_tril_attention_mask --bf16 \
|
||||||
|
--train_batch_size 2 --val_batch_size 4 --accumulate_grad_batches 128
|
||||||
|
```
|
||||||
|
- cpu_offload
|
||||||
|
```
|
||||||
|
python lit_train.py --model_name gpt2 --use_tril_attention_mask --bf16 \
|
||||||
|
--strategy fsdp_cpu_offload
|
||||||
|
```
|
||||||
|
```
|
||||||
|
python lit_train.py --model_name gpt2 --use_tril_attention_mask --bf16 \
|
||||||
|
--strategy deepspeed_stage_3_offload
|
||||||
|
```
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,60 @@
|
||||||
|
import importlib
|
||||||
|
from collections import OrderedDict
|
||||||
|
|
||||||
|
from transformers.models.auto import auto_factory, configuration_auto
|
||||||
|
|
||||||
|
CONFIG_MAPPING_NAMES = OrderedDict([])
|
||||||
|
|
||||||
|
|
||||||
|
def register_custom_configs():
|
||||||
|
for model_type, map_name in CONFIG_MAPPING_NAMES.items():
|
||||||
|
module_name = configuration_auto.model_type_to_module_name(model_type)
|
||||||
|
module = importlib.import_module(f".{module_name}", "custom_models")
|
||||||
|
mapping = getattr(module, map_name)
|
||||||
|
configuration_auto.AutoConfig.register(model_type, mapping)
|
||||||
|
|
||||||
|
|
||||||
|
class _LazyAutoMapping(auto_factory._LazyAutoMapping):
|
||||||
|
def _load_attr_from_module(self, model_type, attr):
|
||||||
|
module_name = auto_factory.model_type_to_module_name(model_type)
|
||||||
|
if module_name not in self._modules:
|
||||||
|
self._modules[module_name] = importlib.import_module(f".{module_name}", "custom_models")
|
||||||
|
return auto_factory.getattribute_from_module(self._modules[module_name], attr)
|
||||||
|
|
||||||
|
|
||||||
|
MODEL_MAPPING_NAMES = OrderedDict(
|
||||||
|
[
|
||||||
|
("gpt2", "GPT2Model"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
|
||||||
|
[
|
||||||
|
("gpt2", "GPT2LMHeadModel"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
MODEL_MAPPING = _LazyAutoMapping(
|
||||||
|
{**configuration_auto.CONFIG_MAPPING_NAMES, **CONFIG_MAPPING_NAMES}, MODEL_MAPPING_NAMES
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(
|
||||||
|
{**configuration_auto.CONFIG_MAPPING_NAMES, **CONFIG_MAPPING_NAMES}, MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class AutoModel(auto_factory._BaseAutoModelClass):
|
||||||
|
_model_mapping = MODEL_MAPPING
|
||||||
|
|
||||||
|
|
||||||
|
AutoModel = auto_factory.auto_class_update(AutoModel)
|
||||||
|
|
||||||
|
|
||||||
|
class AutoModelForCausalLM(auto_factory._BaseAutoModelClass):
|
||||||
|
_model_mapping = MODEL_FOR_CAUSAL_LM_MAPPING
|
||||||
|
|
||||||
|
|
||||||
|
AutoModelForCausalLM = auto_factory.auto_class_update(AutoModelForCausalLM)
|
|
@ -0,0 +1 @@
|
||||||
|
from .modeling_gpt2 import GPT2LMHeadModel, GPT2Model
|
|
@ -0,0 +1,355 @@
|
||||||
|
"""Override transformers GPT2 to support tril attention mask"""
|
||||||
|
|
||||||
|
from typing import Any, Dict, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.utils.checkpoint
|
||||||
|
import transformers
|
||||||
|
from torch import nn
|
||||||
|
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
|
||||||
|
from transformers.models.gpt2.modeling_gpt2 import (
|
||||||
|
_CHECKPOINT_FOR_DOC,
|
||||||
|
_CONFIG_FOR_DOC,
|
||||||
|
GPT2_INPUTS_DOCSTRING,
|
||||||
|
logger,
|
||||||
|
)
|
||||||
|
from transformers.utils import (
|
||||||
|
ModelOutput,
|
||||||
|
add_code_sample_docstrings,
|
||||||
|
add_start_docstrings_to_model_forward,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class GPT2Model(transformers.models.gpt2.GPT2Model):
|
||||||
|
@add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
|
||||||
|
@add_code_sample_docstrings(
|
||||||
|
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||||
|
output_type=BaseModelOutputWithPastAndCrossAttentions,
|
||||||
|
config_class=_CONFIG_FOR_DOC,
|
||||||
|
)
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
||||||
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
|
token_type_ids: Optional[torch.LongTensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
head_mask: Optional[torch.FloatTensor] = None,
|
||||||
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||||
|
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
|
||||||
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
if input_ids is not None and inputs_embeds is not None:
|
||||||
|
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||||
|
elif input_ids is not None:
|
||||||
|
input_shape = input_ids.size()
|
||||||
|
input_ids = input_ids.view(-1, input_shape[-1])
|
||||||
|
batch_size = input_ids.shape[0]
|
||||||
|
elif inputs_embeds is not None:
|
||||||
|
input_shape = inputs_embeds.size()[:-1]
|
||||||
|
batch_size = inputs_embeds.shape[0]
|
||||||
|
else:
|
||||||
|
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||||
|
|
||||||
|
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||||
|
|
||||||
|
if token_type_ids is not None:
|
||||||
|
token_type_ids = token_type_ids.view(-1, input_shape[-1])
|
||||||
|
if position_ids is not None:
|
||||||
|
position_ids = position_ids.view(-1, input_shape[-1])
|
||||||
|
|
||||||
|
if past_key_values is None:
|
||||||
|
past_length = 0
|
||||||
|
past_key_values = tuple([None] * len(self.h))
|
||||||
|
else:
|
||||||
|
past_length = past_key_values[0][0].size(-2)
|
||||||
|
if position_ids is None:
|
||||||
|
position_ids = torch.arange(
|
||||||
|
past_length,
|
||||||
|
input_shape[-1] + past_length,
|
||||||
|
dtype=torch.long,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
|
||||||
|
|
||||||
|
# GPT2Attention mask.
|
||||||
|
if attention_mask is not None:
|
||||||
|
if batch_size <= 0:
|
||||||
|
raise ValueError("batch_size has to be defined and > 0")
|
||||||
|
if attention_mask.dim() == 2:
|
||||||
|
# We create a 3D attention mask from a 2D tensor mask.
|
||||||
|
# Sizes are [batch_size, 1, 1, to_seq_length]
|
||||||
|
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
||||||
|
# this attention mask is more simple than the triangular masking of causal attention
|
||||||
|
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
|
||||||
|
attention_mask = attention_mask[:, None, None, :]
|
||||||
|
elif attention_mask.dim() == 3:
|
||||||
|
attention_mask = attention_mask[:, None, ...]
|
||||||
|
else:
|
||||||
|
raise ValueError(f"attention_mask.dim() is {attention_mask.dim()}, should be 2 or 3")
|
||||||
|
|
||||||
|
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
||||||
|
# masked positions, this operation will create a tensor which is 0.0 for
|
||||||
|
# positions we want to attend and the dtype's smallest value for masked positions.
|
||||||
|
# Since we are adding it to the raw scores before the softmax, this is
|
||||||
|
# effectively the same as removing these entirely.
|
||||||
|
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
|
||||||
|
attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
|
||||||
|
|
||||||
|
# If a 2D or 3D attention mask is provided for the cross-attention
|
||||||
|
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
||||||
|
if self.config.add_cross_attention and encoder_hidden_states is not None:
|
||||||
|
(
|
||||||
|
encoder_batch_size,
|
||||||
|
encoder_sequence_length,
|
||||||
|
_,
|
||||||
|
) = encoder_hidden_states.size()
|
||||||
|
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
|
||||||
|
if encoder_attention_mask is None:
|
||||||
|
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
|
||||||
|
encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
||||||
|
else:
|
||||||
|
encoder_attention_mask = None
|
||||||
|
|
||||||
|
# Prepare head mask if needed
|
||||||
|
# 1.0 in head_mask indicate we keep the head
|
||||||
|
# attention_probs has shape bsz x n_heads x N x N
|
||||||
|
# head_mask has shape n_layer x batch x n_heads x N x N
|
||||||
|
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
|
||||||
|
|
||||||
|
if inputs_embeds is None:
|
||||||
|
inputs_embeds = self.wte(input_ids)
|
||||||
|
position_embeds = self.wpe(position_ids)
|
||||||
|
hidden_states = inputs_embeds + position_embeds
|
||||||
|
|
||||||
|
if token_type_ids is not None:
|
||||||
|
token_type_embeds = self.wte(token_type_ids)
|
||||||
|
hidden_states = hidden_states + token_type_embeds
|
||||||
|
|
||||||
|
hidden_states = self.drop(hidden_states)
|
||||||
|
|
||||||
|
output_shape = input_shape + (hidden_states.size(-1),)
|
||||||
|
|
||||||
|
if self.gradient_checkpointing and self.training:
|
||||||
|
if use_cache:
|
||||||
|
logger.warning_once(
|
||||||
|
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||||
|
)
|
||||||
|
use_cache = False
|
||||||
|
|
||||||
|
presents = () if use_cache else None
|
||||||
|
all_self_attentions = () if output_attentions else None
|
||||||
|
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
||||||
|
all_hidden_states = () if output_hidden_states else None
|
||||||
|
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
|
||||||
|
# Model parallel
|
||||||
|
if self.model_parallel:
|
||||||
|
torch.cuda.set_device(hidden_states.device)
|
||||||
|
# Ensure layer_past is on same device as hidden_states (might not be correct)
|
||||||
|
if layer_past is not None:
|
||||||
|
layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
|
||||||
|
# Ensure that attention_mask is always on the same device as hidden_states
|
||||||
|
if attention_mask is not None:
|
||||||
|
attention_mask = attention_mask.to(hidden_states.device)
|
||||||
|
if isinstance(head_mask, torch.Tensor):
|
||||||
|
head_mask = head_mask.to(hidden_states.device)
|
||||||
|
if output_hidden_states:
|
||||||
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
|
|
||||||
|
if self.gradient_checkpointing and self.training:
|
||||||
|
|
||||||
|
def create_custom_forward(module):
|
||||||
|
def custom_forward(*inputs):
|
||||||
|
# None for past_key_value
|
||||||
|
return module(*inputs, use_cache, output_attentions)
|
||||||
|
|
||||||
|
return custom_forward
|
||||||
|
|
||||||
|
outputs = torch.utils.checkpoint.checkpoint(
|
||||||
|
create_custom_forward(block),
|
||||||
|
hidden_states,
|
||||||
|
None,
|
||||||
|
attention_mask,
|
||||||
|
head_mask[i],
|
||||||
|
encoder_hidden_states,
|
||||||
|
encoder_attention_mask,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
outputs = block(
|
||||||
|
hidden_states,
|
||||||
|
layer_past=layer_past,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
head_mask=head_mask[i],
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = outputs[0]
|
||||||
|
if use_cache is True:
|
||||||
|
presents = presents + (outputs[1],)
|
||||||
|
|
||||||
|
if output_attentions:
|
||||||
|
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
|
||||||
|
if self.config.add_cross_attention:
|
||||||
|
all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
|
||||||
|
|
||||||
|
# Model Parallel: If it's the last layer for that device, put things on the next device
|
||||||
|
if self.model_parallel:
|
||||||
|
for k, v in self.device_map.items():
|
||||||
|
if i == v[-1] and "cuda:" + str(k) != self.last_device:
|
||||||
|
hidden_states = hidden_states.to("cuda:" + str(k + 1))
|
||||||
|
|
||||||
|
hidden_states = self.ln_f(hidden_states)
|
||||||
|
|
||||||
|
hidden_states = hidden_states.view(output_shape)
|
||||||
|
# Add last hidden state
|
||||||
|
if output_hidden_states:
|
||||||
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
return tuple(
|
||||||
|
v
|
||||||
|
for v in [
|
||||||
|
hidden_states,
|
||||||
|
presents,
|
||||||
|
all_hidden_states,
|
||||||
|
all_self_attentions,
|
||||||
|
all_cross_attentions,
|
||||||
|
]
|
||||||
|
if v is not None
|
||||||
|
)
|
||||||
|
|
||||||
|
return BaseModelOutputWithPastAndCrossAttentions(
|
||||||
|
last_hidden_state=hidden_states,
|
||||||
|
past_key_values=presents,
|
||||||
|
hidden_states=all_hidden_states,
|
||||||
|
attentions=all_self_attentions,
|
||||||
|
cross_attentions=all_cross_attentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class GPT2LMHeadModel(transformers.models.gpt2.GPT2LMHeadModel):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__(config)
|
||||||
|
self.transformer = GPT2Model(config)
|
||||||
|
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
|
||||||
|
|
||||||
|
# Model parallel
|
||||||
|
self.model_parallel = False
|
||||||
|
self.device_map = None
|
||||||
|
|
||||||
|
# Initialize weights and apply final processing
|
||||||
|
self.post_init()
|
||||||
|
|
||||||
|
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
|
||||||
|
token_type_ids = kwargs.get("token_type_ids", None)
|
||||||
|
# only last token for inputs_ids if past is defined in kwargs
|
||||||
|
if past_key_values:
|
||||||
|
input_ids = input_ids[:, -1].unsqueeze(-1)
|
||||||
|
if token_type_ids is not None:
|
||||||
|
token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
|
||||||
|
|
||||||
|
attention_mask = kwargs.get("attention_mask", None)
|
||||||
|
position_ids = kwargs.get("position_ids", None)
|
||||||
|
|
||||||
|
if attention_mask is not None and position_ids is None:
|
||||||
|
# create position_ids on the fly for batch generation
|
||||||
|
position_ids = attention_mask.long().cumsum(-1) - 1
|
||||||
|
position_ids.masked_fill_(attention_mask == 0, 1)
|
||||||
|
if past_key_values:
|
||||||
|
position_ids = position_ids[:, -1].unsqueeze(-1)
|
||||||
|
|
||||||
|
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
||||||
|
if inputs_embeds is not None and past_key_values is None:
|
||||||
|
model_inputs = {"inputs_embeds": inputs_embeds}
|
||||||
|
else:
|
||||||
|
model_inputs = {"input_ids": input_ids}
|
||||||
|
|
||||||
|
model_inputs.update(
|
||||||
|
{
|
||||||
|
"past_key_values": past_key_values,
|
||||||
|
"use_cache": kwargs.get("use_cache"),
|
||||||
|
"position_ids": position_ids,
|
||||||
|
"attention_mask": attention_mask,
|
||||||
|
"token_type_ids": token_type_ids,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return model_inputs
|
||||||
|
|
||||||
|
def _update_model_kwargs_for_generation(
|
||||||
|
self,
|
||||||
|
outputs: ModelOutput,
|
||||||
|
model_kwargs: Dict[str, Any],
|
||||||
|
is_encoder_decoder: bool = False,
|
||||||
|
standardize_cache_format: bool = False,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
# update past_key_values
|
||||||
|
model_kwargs["past_key_values"] = self._extract_past_from_model_output(
|
||||||
|
outputs, standardize_cache_format=standardize_cache_format
|
||||||
|
)
|
||||||
|
|
||||||
|
# update token_type_ids with last value
|
||||||
|
if "token_type_ids" in model_kwargs:
|
||||||
|
token_type_ids = model_kwargs["token_type_ids"]
|
||||||
|
model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)
|
||||||
|
|
||||||
|
# update position_ids
|
||||||
|
if "position_ids" in model_kwargs:
|
||||||
|
position_ids = model_kwargs["position_ids"]
|
||||||
|
if model_kwargs["past_key_values"] is not None:
|
||||||
|
model_kwargs["position_ids"] = (position_ids[:, -1] + 1).unsqueeze(-1)
|
||||||
|
else:
|
||||||
|
model_kwargs["position_ids"] = torch.cat(
|
||||||
|
[position_ids, (position_ids[:, -1] + 1).unsqueeze(-1)], dim=-1
|
||||||
|
)
|
||||||
|
|
||||||
|
if not is_encoder_decoder:
|
||||||
|
# update attention mask
|
||||||
|
if "attention_mask" in model_kwargs:
|
||||||
|
attention_mask = model_kwargs["attention_mask"]
|
||||||
|
if attention_mask.dim() == 2:
|
||||||
|
model_kwargs["attention_mask"] = torch.cat(
|
||||||
|
[
|
||||||
|
attention_mask,
|
||||||
|
attention_mask.new_ones((attention_mask.shape[0], 1)),
|
||||||
|
],
|
||||||
|
dim=-1,
|
||||||
|
)
|
||||||
|
elif attention_mask.dim() == 3:
|
||||||
|
attention_mask = attention_mask[:, -1, :]
|
||||||
|
attention_mask = torch.cat(
|
||||||
|
[
|
||||||
|
attention_mask,
|
||||||
|
attention_mask.new_ones((attention_mask.shape[0], 1)),
|
||||||
|
],
|
||||||
|
dim=-1,
|
||||||
|
)
|
||||||
|
model_kwargs["attention_mask"] = attention_mask
|
||||||
|
else:
|
||||||
|
raise ValueError(f"attention_mask.dim() is {attention_mask.dim()}, should be 2 or 3")
|
||||||
|
else:
|
||||||
|
# update decoder attention mask
|
||||||
|
if "decoder_attention_mask" in model_kwargs:
|
||||||
|
decoder_attention_mask = model_kwargs["decoder_attention_mask"]
|
||||||
|
model_kwargs["decoder_attention_mask"] = torch.cat(
|
||||||
|
[
|
||||||
|
decoder_attention_mask,
|
||||||
|
decoder_attention_mask.new_ones((decoder_attention_mask.shape[0], 1)),
|
||||||
|
],
|
||||||
|
dim=-1,
|
||||||
|
)
|
||||||
|
|
||||||
|
return model_kwargs
|
16
generate.py
16
generate.py
|
@ -13,15 +13,11 @@ def eval_prompts(
|
||||||
prompts: List[str],
|
prompts: List[str],
|
||||||
use_tril_attention_mask: bool = False,
|
use_tril_attention_mask: bool = False,
|
||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
inputs = tokenizer(
|
inputs = tokenizer(prompts, padding=True, return_tensors='pt', return_attention_mask=True)
|
||||||
prompts, padding=True, return_tensors='pt', return_attention_mask=True
|
|
||||||
)
|
|
||||||
inputs['position_ids'] = inputs.attention_mask.cumsum(-1) - 1
|
inputs['position_ids'] = inputs.attention_mask.cumsum(-1) - 1
|
||||||
inputs['position_ids'].masked_fill_(inputs.attention_mask == 0, 1)
|
inputs['position_ids'].masked_fill_(inputs.attention_mask == 0, 1)
|
||||||
if use_tril_attention_mask:
|
if use_tril_attention_mask:
|
||||||
inputs['attention_mask'] = (
|
inputs['attention_mask'] = (inputs.attention_mask.unsqueeze(1) * inputs.attention_mask.unsqueeze(2)).tril()
|
||||||
inputs.attention_mask.unsqueeze(1) * inputs.attention_mask.unsqueeze(2)
|
|
||||||
).tril()
|
|
||||||
inputs = inputs.to(model.device)
|
inputs = inputs.to(model.device)
|
||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
output_ids = model.generate(
|
output_ids = model.generate(
|
||||||
|
@ -32,9 +28,7 @@ def eval_prompts(
|
||||||
eos_token_id=tokenizer.eos_token_id,
|
eos_token_id=tokenizer.eos_token_id,
|
||||||
early_stopping=True,
|
early_stopping=True,
|
||||||
)
|
)
|
||||||
completes = tokenizer.batch_decode(
|
completes = tokenizer.batch_decode(output_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
|
||||||
output_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
|
||||||
)
|
|
||||||
return completes
|
return completes
|
||||||
|
|
||||||
|
|
||||||
|
@ -81,9 +75,7 @@ if __name__ == '__main__':
|
||||||
"这是一个最好的时代,这是一个最坏的时代。",
|
"这是一个最好的时代,这是一个最坏的时代。",
|
||||||
"这是一个最好的时代,这是一个最坏的",
|
"这是一个最好的时代,这是一个最坏的",
|
||||||
]
|
]
|
||||||
completes = eval_prompts(
|
completes = eval_prompts(model, tokenizer, prompts, use_tril_attention_mask=args.use_tril_attention_mask)
|
||||||
model, tokenizer, prompts, use_tril_attention_mask=args.use_tril_attention_mask
|
|
||||||
)
|
|
||||||
|
|
||||||
for prompt, complete in zip(prompts, completes):
|
for prompt, complete in zip(prompts, completes):
|
||||||
print("[p]", prompt)
|
print("[p]", prompt)
|
||||||
|
|
|
@ -26,8 +26,6 @@ if __name__ == '__main__':
|
||||||
|
|
||||||
checkpoint_file_path = next(lightning_logs_dir_path.glob("checkpoints/*.ckpt"))
|
checkpoint_file_path = next(lightning_logs_dir_path.glob("checkpoints/*.ckpt"))
|
||||||
|
|
||||||
lit_module = LitModule.load_from_checkpoint(
|
lit_module = LitModule.load_from_checkpoint(checkpoint_file_path, map_location='cpu')
|
||||||
checkpoint_file_path, map_location='cpu'
|
|
||||||
)
|
|
||||||
model: PreTrainedModel = lit_module.__core_module__
|
model: PreTrainedModel = lit_module.__core_module__
|
||||||
model.save_pretrained(exports_dir_path)
|
model.save_pretrained(exports_dir_path)
|
||||||
|
|
|
@ -27,9 +27,7 @@ class LitModule(pl.LightningModule):
|
||||||
)
|
)
|
||||||
|
|
||||||
@cache
|
@cache
|
||||||
def get_batch_tril_matrix(
|
def get_batch_tril_matrix(self, block_size: int, batch_size: Optional[int] = None) -> torch.Tensor:
|
||||||
self, block_size: int, batch_size: Optional[int] = None
|
|
||||||
) -> torch.Tensor:
|
|
||||||
matrix = torch.ones(block_size, block_size).tril()
|
matrix = torch.ones(block_size, block_size).tril()
|
||||||
if batch_size is not None:
|
if batch_size is not None:
|
||||||
matrix = matrix.repeat(batch_size, 1, 1)
|
matrix = matrix.repeat(batch_size, 1, 1)
|
||||||
|
@ -42,9 +40,7 @@ class LitModule(pl.LightningModule):
|
||||||
def training_step(self, batch: Dict[str, torch.Tensor], batch_idx):
|
def training_step(self, batch: Dict[str, torch.Tensor], batch_idx):
|
||||||
batch_size, block_size = batch['input_ids'].shape
|
batch_size, block_size = batch['input_ids'].shape
|
||||||
if self.use_tril_attention_mask:
|
if self.use_tril_attention_mask:
|
||||||
batch['attention_mask'] = self.get_batch_tril_matrix(
|
batch['attention_mask'] = self.get_batch_tril_matrix(block_size, batch_size=batch_size).to(self.device)
|
||||||
block_size, batch_size=batch_size
|
|
||||||
).to(self.device)
|
|
||||||
outputs = self.llm(**batch, return_dict=True)
|
outputs = self.llm(**batch, return_dict=True)
|
||||||
loss = outputs.loss
|
loss = outputs.loss
|
||||||
|
|
||||||
|
@ -80,9 +76,7 @@ class LitModule(pl.LightningModule):
|
||||||
self.trainer.model.parameters(), lr=self.learning_rate
|
self.trainer.model.parameters(), lr=self.learning_rate
|
||||||
)
|
)
|
||||||
return optimizer
|
return optimizer
|
||||||
optimizer = torch.optim.AdamW(
|
optimizer = torch.optim.AdamW(self.trainer.model.parameters(), lr=self.learning_rate)
|
||||||
self.trainer.model.parameters(), lr=self.learning_rate
|
|
||||||
)
|
|
||||||
return optimizer
|
return optimizer
|
||||||
|
|
||||||
def configure_callbacks(self):
|
def configure_callbacks(self):
|
||||||
|
|
25
lit_train.py
25
lit_train.py
|
@ -25,16 +25,12 @@ def split_raw_dataset(
|
||||||
if 'validation' in raw_dataset:
|
if 'validation' in raw_dataset:
|
||||||
train_dataset, val_dataset = raw_dataset['train'], raw_dataset['validation']
|
train_dataset, val_dataset = raw_dataset['train'], raw_dataset['validation']
|
||||||
else:
|
else:
|
||||||
raw_dataset = raw_dataset['train'].train_test_split(
|
raw_dataset = raw_dataset['train'].train_test_split(test_size=0.05, seed=args.seed)
|
||||||
test_size=0.05, seed=args.seed
|
|
||||||
)
|
|
||||||
train_dataset, val_dataset = raw_dataset['train'], raw_dataset['test']
|
train_dataset, val_dataset = raw_dataset['train'], raw_dataset['test']
|
||||||
return train_dataset, val_dataset
|
return train_dataset, val_dataset
|
||||||
|
|
||||||
|
|
||||||
def process_dataset(
|
def process_dataset(dataset: datasets.Dataset, tokenizer: PreTrainedTokenizer) -> datasets.Dataset:
|
||||||
dataset: datasets.Dataset, tokenizer: PreTrainedTokenizer
|
|
||||||
) -> datasets.Dataset:
|
|
||||||
def group_texts(examples: Dict[str, list], block_size: int = 512) -> BatchEncoding:
|
def group_texts(examples: Dict[str, list], block_size: int = 512) -> BatchEncoding:
|
||||||
concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
|
concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
|
||||||
total_length = len(concatenated_examples[list(examples.keys())[0]])
|
total_length = len(concatenated_examples[list(examples.keys())[0]])
|
||||||
|
@ -110,13 +106,13 @@ def parse_args():
|
||||||
"--train_batch_size",
|
"--train_batch_size",
|
||||||
type=int,
|
type=int,
|
||||||
help="Batch size of training",
|
help="Batch size of training",
|
||||||
default=8,
|
default=2,
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--val_batch_size",
|
"--val_batch_size",
|
||||||
type=int,
|
type=int,
|
||||||
help="Batch size of validating",
|
help="Batch size of validating",
|
||||||
default=16,
|
default=2,
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--accumulate_grad_batches",
|
"--accumulate_grad_batches",
|
||||||
|
@ -128,7 +124,7 @@ def parse_args():
|
||||||
"--num_proc",
|
"--num_proc",
|
||||||
type=str,
|
type=str,
|
||||||
help="Number of data processes",
|
help="Number of data processes",
|
||||||
default=16,
|
default=1,
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--max_epochs",
|
"--max_epochs",
|
||||||
|
@ -140,7 +136,7 @@ def parse_args():
|
||||||
"--strategy",
|
"--strategy",
|
||||||
type=str,
|
type=str,
|
||||||
help="Name of pytorch lightning distribution strategy",
|
help="Name of pytorch lightning distribution strategy",
|
||||||
default='fsdp',
|
default='ddp',
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--resume_from_ckpt_path",
|
"--resume_from_ckpt_path",
|
||||||
|
@ -167,9 +163,7 @@ if __name__ == '__main__':
|
||||||
set_seed(args.seed)
|
set_seed(args.seed)
|
||||||
|
|
||||||
# lightning module
|
# lightning module
|
||||||
lit_module = LitModule(
|
lit_module = LitModule(args.model_name, args.learning_rate, args.use_tril_attention_mask)
|
||||||
args.model_name, args.learning_rate, args.use_tril_attention_mask
|
|
||||||
)
|
|
||||||
|
|
||||||
# datasets
|
# datasets
|
||||||
tokenizer = load_tokenizer(args.tokenizer_name_or_path)
|
tokenizer = load_tokenizer(args.tokenizer_name_or_path)
|
||||||
|
@ -203,8 +197,11 @@ if __name__ == '__main__':
|
||||||
persistent_workers=True,
|
persistent_workers=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
ne = next(train_dataloader._get_iterator())
|
||||||
|
print((ne["input_ids"]-ne["labels"]).numpy().tolist())
|
||||||
|
|
||||||
# trainer
|
# trainer
|
||||||
apply_all_patches()
|
# apply_all_patches()
|
||||||
torch.set_float32_matmul_precision('medium')
|
torch.set_float32_matmul_precision('medium')
|
||||||
if args.bf16:
|
if args.bf16:
|
||||||
precision = 'bf16-mixed'
|
precision = 'bf16-mixed'
|
||||||
|
|
35
utils.py
35
utils.py
|
@ -10,9 +10,19 @@ from transformers import (
|
||||||
PreTrainedTokenizer,
|
PreTrainedTokenizer,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
import custom_models
|
||||||
|
|
||||||
def init_model(model_name: Union[str, os.PathLike]) -> PreTrainedModel:
|
custom_models.register_custom_configs()
|
||||||
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
|
|
||||||
|
|
||||||
|
def init_model(model_name: str) -> PreTrainedModel:
|
||||||
|
config = AutoConfig.for_model(model_type=model_name)
|
||||||
|
|
||||||
|
if model_name in custom_models.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
|
||||||
|
model = custom_models.AutoModelForCausalLM.from_config(config)
|
||||||
|
elif model_name in custom_models.MODEL_MAPPING_NAMES:
|
||||||
|
model = custom_models.AutoModel.from_config(config)
|
||||||
|
else:
|
||||||
try:
|
try:
|
||||||
model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
|
model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
|
@ -21,21 +31,22 @@ def init_model(model_name: Union[str, os.PathLike]) -> PreTrainedModel:
|
||||||
|
|
||||||
|
|
||||||
def load_model(model_name_or_path: Union[str, os.PathLike]) -> PreTrainedModel:
|
def load_model(model_name_or_path: Union[str, os.PathLike]) -> PreTrainedModel:
|
||||||
|
config = AutoConfig.from_pretrained(model_name_or_path)
|
||||||
|
|
||||||
|
if config.model_type in custom_models.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
|
||||||
|
model = custom_models.AutoModelForCausalLM.from_pretrained(model_name_or_path, config=config)
|
||||||
|
elif config.model_type in custom_models.MODEL_MAPPING_NAMES:
|
||||||
|
model = custom_models.AutoModel.from_pretrained(model_name_or_path, config=config)
|
||||||
|
else:
|
||||||
try:
|
try:
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
model = AutoModelForCausalLM.from_pretrained(model_name_or_path, config=config, trust_remote_code=True)
|
||||||
model_name_or_path, trust_remote_code=True
|
|
||||||
)
|
|
||||||
except ValueError:
|
except ValueError:
|
||||||
model = AutoModel.from_pretrained(model_name_or_path, trust_remote_code=True)
|
model = AutoModel.from_pretrained(model_name_or_path, config=config, trust_remote_code=True)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
def load_tokenizer(
|
def load_tokenizer(tokenizer_name_or_path: Union[str, os.PathLike]) -> PreTrainedTokenizer:
|
||||||
tokenizer_name_or_path: Union[str, os.PathLike]
|
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, padding_side='left', trust_remote_code=True)
|
||||||
) -> PreTrainedTokenizer:
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(
|
|
||||||
tokenizer_name_or_path, padding_side='left', trust_remote_code=True
|
|
||||||
)
|
|
||||||
if tokenizer.pad_token_id is None:
|
if tokenizer.pad_token_id is None:
|
||||||
tokenizer.pad_token = tokenizer.eos_token
|
tokenizer.pad_token = tokenizer.eos_token
|
||||||
return tokenizer
|
return tokenizer
|
||||||
|
|
Loading…
Reference in New Issue