Format code.

This commit is contained in:
Colin 2024-01-01 10:20:04 +08:00
parent bde8f71a7f
commit b3ef30aa1a
4 changed files with 46 additions and 90 deletions

View File

@ -3,6 +3,7 @@ from transformers import PretrainedConfig
class ChatGLMConfig(PretrainedConfig): class ChatGLMConfig(PretrainedConfig):
model_type = "chatglm" model_type = "chatglm"
def __init__( def __init__(
self, self,
num_layers=28, num_layers=28,

View File

@ -26,9 +26,7 @@ from tools import show
class RotaryEmbedding(nn.Module): class RotaryEmbedding(nn.Module):
def __init__(self, dim: int, original_impl=False, device=None, dtype=None): def __init__(self, dim: int, original_impl=False, device=None, dtype=None):
super().__init__() super().__init__()
inv_freq = 1.0 / ( inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim))
10000 ** (torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim)
)
self.register_buffer("inv_freq", inv_freq) self.register_buffer("inv_freq", inv_freq)
self.dim = dim self.dim = dim
self.original_impl = original_impl self.original_impl = original_impl
@ -37,13 +35,7 @@ class RotaryEmbedding(nn.Module):
dtype = self.inv_freq.dtype dtype = self.inv_freq.dtype
device = self.inv_freq.device device = self.inv_freq.device
# $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$ # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
theta = 1.0 / ( theta = 1.0 / (base ** (torch.arange(0, self.dim, 2, dtype=torch.float, device=device) / self.dim))
base
** (
torch.arange(0, self.dim, 2, dtype=torch.float, device=device)
/ self.dim
)
)
# Create position indexes `[0, 1, ..., max_seq_len - 1]` # Create position indexes `[0, 1, ..., max_seq_len - 1]`
seq_idx = torch.arange(max_seq_len, dtype=torch.float, device=device) seq_idx = torch.arange(max_seq_len, dtype=torch.float, device=device)
# Calculate the product of position index and $\theta_i$ # Calculate the product of position index and $\theta_i$
@ -58,9 +50,7 @@ class RotaryEmbedding(nn.Module):
class RMSNorm(torch.nn.Module): class RMSNorm(torch.nn.Module):
def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs): def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs):
super().__init__() super().__init__()
self.weight = torch.nn.Parameter( self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype))
torch.empty(normalized_shape, device=device, dtype=dtype)
)
self.eps = eps self.eps = eps
def forward(self, hidden_states: torch.Tensor): def forward(self, hidden_states: torch.Tensor):
@ -80,9 +70,7 @@ class CoreAttention(torch.nn.Module):
projection_size = config.kv_channels * config.num_attention_heads projection_size = config.kv_channels * config.num_attention_heads
# Per attention head and per partition values. # Per attention head and per partition values.
self.hidden_size_per_partition = projection_size self.hidden_size_per_partition = projection_size
self.hidden_size_per_attention_head = ( self.hidden_size_per_attention_head = projection_size // config.num_attention_heads
projection_size // config.num_attention_heads
)
self.num_attention_heads_per_partition = config.num_attention_heads self.num_attention_heads_per_partition = config.num_attention_heads
coeff = None coeff = None
@ -94,17 +82,13 @@ class CoreAttention(torch.nn.Module):
self.attention_dropout = torch.nn.Dropout(config.attention_dropout) self.attention_dropout = torch.nn.Dropout(config.attention_dropout)
def forward(self, query_layer, key_layer, value_layer): def forward(self, query_layer, key_layer, value_layer):
query_layer, key_layer, value_layer = [ query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]]
k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]
]
if query_layer.shape[2] == key_layer.shape[2]: if query_layer.shape[2] == key_layer.shape[2]:
context_layer = torch.nn.functional.scaled_dot_product_attention( context_layer = torch.nn.functional.scaled_dot_product_attention(
query_layer, key_layer, value_layer, is_causal=True query_layer, key_layer, value_layer, is_causal=True
) )
context_layer = context_layer.permute(2, 0, 1, 3) context_layer = context_layer.permute(2, 0, 1, 3)
new_context_layer_shape = context_layer.size()[:-2] + ( new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
self.hidden_size_per_partition,
)
context_layer = context_layer.reshape(*new_context_layer_shape) context_layer = context_layer.reshape(*new_context_layer_shape)
return context_layer return context_layer
@ -114,16 +98,13 @@ class SelfAttention(torch.nn.Module):
super(SelfAttention, self).__init__() super(SelfAttention, self).__init__()
self.layer_number = max(1, layer_number) self.layer_number = max(1, layer_number)
self.projection_size = config.kv_channels * config.num_attention_heads self.projection_size = config.kv_channels * config.num_attention_heads
self.hidden_size_per_attention_head = ( self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads
self.projection_size // config.num_attention_heads
)
self.num_attention_heads_per_partition = config.num_attention_heads self.num_attention_heads_per_partition = config.num_attention_heads
self.multi_query_attention = config.multi_query_attention self.multi_query_attention = config.multi_query_attention
self.qkv_hidden_size = 3 * self.projection_size self.qkv_hidden_size = 3 * self.projection_size
self.num_multi_query_groups_per_partition = config.multi_query_group_num self.num_multi_query_groups_per_partition = config.multi_query_group_num
self.qkv_hidden_size = ( self.qkv_hidden_size = (
self.projection_size self.projection_size + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num
+ 2 * self.hidden_size_per_attention_head * config.multi_query_group_num
) )
self.query_key_value = nn.Linear( self.query_key_value = nn.Linear(
config.hidden_size, config.hidden_size,
@ -163,12 +144,9 @@ class SelfAttention(torch.nn.Module):
(query_layer, key_layer, value_layer) = mixed_x_layer.split( (query_layer, key_layer, value_layer) = mixed_x_layer.split(
[ [
self.num_attention_heads_per_partition self.num_attention_heads_per_partition * self.hidden_size_per_attention_head,
* self.hidden_size_per_attention_head, self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
self.num_multi_query_groups_per_partition self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
* self.hidden_size_per_attention_head,
self.num_multi_query_groups_per_partition
* self.hidden_size_per_attention_head,
], ],
dim=-1, dim=-1,
) )
@ -204,8 +182,7 @@ class SelfAttention(torch.nn.Module):
-1, -1,
-1, -1,
-1, -1,
self.num_attention_heads_per_partition self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition,
// self.num_multi_query_groups_per_partition,
-1, -1,
) )
key_layer = key_layer.contiguous().view( key_layer = key_layer.contiguous().view(
@ -220,8 +197,7 @@ class SelfAttention(torch.nn.Module):
-1, -1,
-1, -1,
-1, -1,
self.num_attention_heads_per_partition self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition,
// self.num_multi_query_groups_per_partition,
-1, -1,
) )
value_layer = value_layer.contiguous().view( value_layer = value_layer.contiguous().view(
@ -292,9 +268,7 @@ class GLMBlock(torch.nn.Module):
super(GLMBlock, self).__init__() super(GLMBlock, self).__init__()
self.layer_number = layer_number self.layer_number = layer_number
self.apply_residual_connection_post_layernorm = ( self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm
config.apply_residual_connection_post_layernorm
)
self.fp32_residual_connection = config.fp32_residual_connection self.fp32_residual_connection = config.fp32_residual_connection
@ -326,9 +300,7 @@ class GLMBlock(torch.nn.Module):
attention_output = self.self_attention(layernorm_output, rotary_pos_emb) attention_output = self.self_attention(layernorm_output, rotary_pos_emb)
residual = hidden_states residual = hidden_states
layernorm_input = torch.nn.functional.dropout( layernorm_input = torch.nn.functional.dropout(attention_output, p=self.hidden_dropout, training=self.training)
attention_output, p=self.hidden_dropout, training=self.training
)
layernorm_input = residual + layernorm_input layernorm_input = residual + layernorm_input
# Layer norm post the self attention. # Layer norm post the self attention.
@ -339,9 +311,7 @@ class GLMBlock(torch.nn.Module):
residual = layernorm_input residual = layernorm_input
output = torch.nn.functional.dropout( output = torch.nn.functional.dropout(mlp_output, p=self.hidden_dropout, training=self.training)
mlp_output, p=self.hidden_dropout, training=self.training
)
output = residual + output output = residual + output
return output return output
@ -409,9 +379,7 @@ class ChatGLMModel(nn.Module):
# Rotary positional embeddings # Rotary positional embeddings
self.seq_length = config.seq_length self.seq_length = config.seq_length
rotary_dim = ( rotary_dim = (
config.hidden_size // config.num_attention_heads config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels
if config.kv_channels is None
else config.kv_channels
) )
self.rotary_pos_emb = RotaryEmbedding( self.rotary_pos_emb = RotaryEmbedding(
@ -438,9 +406,7 @@ class ChatGLMModel(nn.Module):
tokenizer=None, tokenizer=None,
): ):
output_hidden_states = ( output_hidden_states = (
output_hidden_states output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
) )
inputs_embeds = self.embedding(input_ids) inputs_embeds = self.embedding(input_ids)
@ -475,23 +441,17 @@ class ChatGLMForConditionalGeneration(nn.Module):
self.warnings_issued = {} self.warnings_issued = {}
self.generation_config = GenerationConfig.from_model_config(config) self.generation_config = GenerationConfig.from_model_config(config)
def from_pretrained( def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]]):
cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]]
):
load_in_8bit = False load_in_8bit = False
load_in_4bit = False load_in_4bit = False
pretrained_model_name_or_path = str(pretrained_model_name_or_path) pretrained_model_name_or_path = str(pretrained_model_name_or_path)
resolved_archive_file = os.path.join( resolved_archive_file = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin.index.json")
pretrained_model_name_or_path, "pytorch_model.bin.index.json"
)
print(f"loading weights file {resolved_archive_file}") print(f"loading weights file {resolved_archive_file}")
with open(resolved_archive_file, "r") as f: with open(resolved_archive_file, "r") as f:
index = json.loads(f.read()) index = json.loads(f.read())
shard_filenames = sorted(set(index["weight_map"].values())) shard_filenames = sorted(set(index["weight_map"].values()))
resolved_archive_file = [ resolved_archive_file = [os.path.join(pretrained_model_name_or_path, f) for f in shard_filenames]
os.path.join(pretrained_model_name_or_path, f) for f in shard_filenames
]
model = cls._load_pretrained_model(resolved_archive_file) model = cls._load_pretrained_model(resolved_archive_file)
model.is_loaded_in_4bit = load_in_4bit model.is_loaded_in_4bit = load_in_4bit
model.is_loaded_in_8bit = load_in_8bit model.is_loaded_in_8bit = load_in_8bit
@ -524,21 +484,15 @@ class ChatGLMForConditionalGeneration(nn.Module):
model_to_load = cls model_to_load = cls
error_msgs = [] error_msgs = []
if len(resolved_archive_file) > 1: if len(resolved_archive_file) > 1:
resolved_archive_file = tqdm_lib.tqdm( resolved_archive_file = tqdm_lib.tqdm(resolved_archive_file, desc="Loading checkpoint shards")
resolved_archive_file, desc="Loading checkpoint shards"
)
for shard_file in resolved_archive_file: for shard_file in resolved_archive_file:
state_dict = torch.load(shard_file, map_location="cpu") state_dict = torch.load(shard_file, map_location="cpu")
error_msgs += cls._load_state_dict_into_model( error_msgs += cls._load_state_dict_into_model(model_to_load, state_dict, start_prefix)
model_to_load, state_dict, start_prefix
)
del state_dict # force memory release del state_dict # force memory release
gc.collect() gc.collect()
print( print(f"All model checkpoint weights were used when initializing {cls.__class__.__name__}.\n")
f"All model checkpoint weights were used when initializing {cls.__class__.__name__}.\n"
)
return cls return cls
@torch.inference_mode() @torch.inference_mode()
@ -556,9 +510,7 @@ class ChatGLMForConditionalGeneration(nn.Module):
generation_config = copy.deepcopy(self.generation_config) generation_config = copy.deepcopy(self.generation_config)
inputs_tensor = inputs["input_ids"] inputs_tensor = inputs["input_ids"]
input_ids = inputs_tensor.repeat_interleave( input_ids = inputs_tensor.repeat_interleave(generation_config.num_return_sequences, dim=0)
generation_config.num_return_sequences, dim=0
)
outputs = self.sample( outputs = self.sample(
input_ids, input_ids,
@ -585,17 +537,13 @@ class ChatGLMForConditionalGeneration(nn.Module):
eos_token_id = [eos_token_id] eos_token_id = [eos_token_id]
eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device)
isFinished = torch.zeros( isFinished = torch.zeros(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
input_ids.shape[0], dtype=torch.long, device=input_ids.device
)
# token_count = 0 # token_count = 0
while True: while True:
input_ids_in = input_ids input_ids_in = input_ids
batch_size, seq_length = input_ids_in.shape batch_size, seq_length = input_ids_in.shape
position_ids_in = ( position_ids_in = (
torch.arange(seq_length, dtype=torch.long, device=input_ids.device) torch.arange(seq_length, dtype=torch.long, device=input_ids.device).unsqueeze(0).repeat(batch_size, 1)
.unsqueeze(0)
.repeat(batch_size, 1)
) )
model_inputs = {"input_ids": input_ids_in, "position_ids": position_ids_in} model_inputs = {"input_ids": input_ids_in, "position_ids": position_ids_in}

View File

@ -19,8 +19,17 @@ class SPTokenizer:
self.pad_id: int = self.sp_model.unk_id() self.pad_id: int = self.sp_model.unk_id()
assert self.sp_model.vocab_size() == self.sp_model.get_piece_size() assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()
special_tokens = ["[MASK]", "[gMASK]", "[sMASK]", "sop", "eop", "<|system|>", "<|user|>", "<|assistant|>", special_tokens = [
"<|observation|>"] "[MASK]",
"[gMASK]",
"[sMASK]",
"sop",
"eop",
"<|system|>",
"<|user|>",
"<|assistant|>",
"<|observation|>",
]
self.special_tokens = {} self.special_tokens = {}
self.index_special_tokens = {} self.index_special_tokens = {}
for token in special_tokens: for token in special_tokens:
@ -59,7 +68,7 @@ class SPTokenizer:
return text return text
def convert_token_to_id(self, token): def convert_token_to_id(self, token):
""" Converts a token (str) in an id using the vocab. """ """Converts a token (str) in an id using the vocab."""
if token in self.special_tokens: if token in self.special_tokens:
return self.special_tokens[token] return self.special_tokens[token]
return self.sp_model.PieceToId(token) return self.sp_model.PieceToId(token)
@ -86,7 +95,7 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
self.special_tokens = { self.special_tokens = {
"<bos>": self.tokenizer.bos_id, "<bos>": self.tokenizer.bos_id,
"<eos>": self.tokenizer.eos_id, "<eos>": self.tokenizer.eos_id,
"<pad>": self.tokenizer.pad_id "<pad>": self.tokenizer.pad_id,
} }
super().__init__(padding_side=padding_side, clean_up_tokenization_spaces=clean_up_tokenization_spaces, **kwargs) super().__init__(padding_side=padding_side, clean_up_tokenization_spaces=clean_up_tokenization_spaces, **kwargs)
@ -121,7 +130,7 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
return self.tokenizer.n_words return self.tokenizer.n_words
def get_vocab(self): def get_vocab(self):
""" Returns vocab as a dict """ """Returns vocab as a dict"""
vocab = {self._convert_id_to_token(i): i for i in range(self.vocab_size)} vocab = {self._convert_id_to_token(i): i for i in range(self.vocab_size)}
vocab.update(self.added_tokens_encoder) vocab.update(self.added_tokens_encoder)
return vocab return vocab
@ -130,7 +139,7 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
return self.tokenizer.tokenize(text) return self.tokenizer.tokenize(text)
def _convert_token_to_id(self, token): def _convert_token_to_id(self, token):
""" Converts a token (str) in an id using the vocab. """ """Converts a token (str) in an id using the vocab."""
return self.tokenizer.convert_token_to_id(token) return self.tokenizer.convert_token_to_id(token)
def _convert_id_to_token(self, index): def _convert_id_to_token(self, index):

View File

@ -24,8 +24,7 @@ tokenizer = ChatGLMTokenizer(*init_inputs, **init_kwargs)
a = tokenizer.encode("") a = tokenizer.encode("")
b = tokenizer.decode([236,173,140]) b = tokenizer.decode([236, 173, 140])
token = [] token = []
@ -49,4 +48,3 @@ show.DumpListToFile(token, "generated/token.log")
# # name = "generated/next_tokens" + str(token_count) + "_" + response + "_.png" # # name = "generated/next_tokens" + str(token_count) + "_" + response + "_.png"
# # show.DumpTensorToImage(next_token_logits[0], name) # # show.DumpTensorToImage(next_token_logits[0], name)
# # token_count = token_count + 1 # # token_count = token_count + 1