Update code.

This commit is contained in:
Colin 2023-12-25 17:26:19 +08:00
parent fa7078b72d
commit 0bc7bc90b1
3 changed files with 36 additions and 119 deletions

View File

@ -5,7 +5,7 @@
input_ids = tokenizer.build_chat_input(query, history=history, role=role) input_ids = tokenizer.build_chat_input(query, history=history, role=role)
for for:
input_ids -> [1, 6] 1:batch_num 6:sequence_length input_ids -> [1, 6] 1:batch_num 6:sequence_length
inputs_embeds -> [6, 1, 4096] 4096:hidden_size inputs_embeds -> [6, 1, 4096] 4096:hidden_size
rotary_pos_emb -> [6, 1, 32, 2] 32:pos的编码维度 2:cos+sin rotary_pos_emb -> [6, 1, 32, 2] 32:pos的编码维度 2:cos+sin
@ -40,19 +40,6 @@ Linear(hidden_states) no bias -> [6, 1, 27392]
silu (x) = [6, 1, 13696] * sigmoid([6, 1, 13696]) silu (x) = [6, 1, 13696] * sigmoid([6, 1, 13696])
Linear(intermediate_parallel) no bias -> [6, 1, 4096] Linear(intermediate_parallel) no bias -> [6, 1, 4096]
## core_attention
query_layer=query_layer.permute(1, 2, 0, 3) -> [1, 32, 6, 128]
key_layer=key_layer.permute(1, 2, 0, 3) -> [1, 32, 6, 128]
value_layer=value_layer.permute(1, 2, 0, 3) -> [1, 32, 6, 128]
context_layer = scaled_dot_product_attention(query_layer, key_layer, value_layer) -> [1, 32, 6, 128]
softmax(QK^T/sqrt(in_dim))V
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
att = F.softmax(att, dim=-1)
y = att @ v -> (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
context_layer = context_layer.permute(2, 0, 1, 3)
context_layer = context_layer.reshape() -> [6, 1, 4096]
## self_attention ## self_attention
hidden_states: [s, b, h] hidden_states: [s, b, h]
@ -74,7 +61,15 @@ value_layer = value_layer.unsqueeze(-2) -> [6, 1, 2, 1, 128]
value_layer = value_layer.expand -> [6, 1, 2, 16, 128] value_layer = value_layer.expand -> [6, 1, 2, 16, 128]
value_layer = value_layer.contiguous().view -> [6, 1, 32, 128] value_layer = value_layer.contiguous().view -> [6, 1, 32, 128]
context_layer = self.core_attention(query_layer, key_layer, value_layer) -> [6, 1, 4096] query_layer permute(1, 2, 0, 3) -> [1, 32, 6, 128]
key_layer permute(1, 2, 0, 3) -> [1, 32, 6, 128]
value_layer permute(1, 2, 0, 3) -> [1, 32, 6, 128]
context_layer = scaled_dot_product_attention(query_layer, key_layer, value_layer) -> [1, 32, 6, 128]
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
att = F.softmax(att, dim=-1)
y = att @ v -> (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
context_layer = context_layer.permute(2, 0, 1, 3).reshape() -> [6, 1, 4096]
return Linear(context_layer) -> [6, 1, 4096] return Linear(context_layer) -> [6, 1, 4096]
## GLMBlock ## GLMBlock

View File

@ -22,8 +22,6 @@ from transformers.generation import GenerationConfig
from chatglm import ChatGLMConfig from chatglm import ChatGLMConfig
from tools import show from tools import show
WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"
class RotaryEmbedding(nn.Module): class RotaryEmbedding(nn.Module):
def __init__(self, dim: int, original_impl=False, device=None, dtype=None): def __init__(self, dim: int, original_impl=False, device=None, dtype=None):
@ -487,60 +485,26 @@ class ChatGLMForConditionalGeneration(nn.Module):
self.generation_config = GenerationConfig.from_model_config(config) self.generation_config = GenerationConfig.from_model_config(config)
def from_pretrained( def from_pretrained(
cls, cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]]
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
**kwargs,
): ):
state_dict = kwargs.pop("state_dict", None) load_in_8bit = False
_ = kwargs.pop("mirror", None) load_in_4bit = False
load_in_8bit = kwargs.pop("load_in_8bit", False)
load_in_4bit = kwargs.pop("load_in_4bit", False)
subfolder = kwargs.pop("subfolder", "")
# This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
# index of the files.
sharded_metadata = None
# if pretrained_model_name_or_path is not None:
pretrained_model_name_or_path = str(pretrained_model_name_or_path) pretrained_model_name_or_path = str(pretrained_model_name_or_path)
archive_file = os.path.join( resolved_archive_file = os.path.join(
pretrained_model_name_or_path, pretrained_model_name_or_path, "pytorch_model.bin.index.json"
subfolder,
WEIGHTS_INDEX_NAME,
) )
print(f"loading weights file {archive_file}") print(f"loading weights file {resolved_archive_file}")
resolved_archive_file = archive_file
with open(resolved_archive_file, "r") as f: with open(resolved_archive_file, "r") as f:
index = json.loads(f.read()) index = json.loads(f.read())
shard_filenames = sorted(set(index["weight_map"].values())) shard_filenames = sorted(set(index["weight_map"].values()))
sharded_metadata = index["metadata"]
sharded_metadata["all_checkpoint_keys"] = list(index["weight_map"].keys())
sharded_metadata["weight_map"] = index["weight_map"].copy()
resolved_archive_file = [ resolved_archive_file = [
os.path.join(pretrained_model_name_or_path, subfolder, f) os.path.join(pretrained_model_name_or_path, f) for f in shard_filenames
for f in shard_filenames
] ]
model = cls._load_pretrained_model(resolved_archive_file)
loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
config.name_or_path = pretrained_model_name_or_path
config = copy.deepcopy(config)
# make sure we use the model's config since the __init__ call might have copied it
config = cls.config
# restore default dtype
model = cls._load_pretrained_model(
state_dict,
loaded_state_dict_keys,
resolved_archive_file,
pretrained_model_name_or_path,
)
model.is_loaded_in_4bit = load_in_4bit model.is_loaded_in_4bit = load_in_4bit
model.is_loaded_in_8bit = load_in_8bit model.is_loaded_in_8bit = load_in_8bit
# Set model in evaluation mode to deactivate DropOut modules by default model.eval() # Set model in evaluation mode to deactivate DropOut modules by default
model.eval()
return model return model
def _load_state_dict_into_model(self, model_to_load, state_dict, start_prefix): def _load_state_dict_into_model(self, model_to_load, state_dict, start_prefix):
@ -564,40 +528,7 @@ class ChatGLMForConditionalGeneration(nn.Module):
del state_dict del state_dict
return error_msgs return error_msgs
def _load_pretrained_model( def _load_pretrained_model(cls, resolved_archive_file):
cls,
state_dict,
loaded_keys,
resolved_archive_file,
pretrained_model_name_or_path,
):
# Retrieve missing & unexpected_keys
model_state_dict = cls.state_dict()
expected_keys = list(model_state_dict.keys())
loaded_keys = [key for key in loaded_keys]
unexpected_keys = set(loaded_keys) - set(expected_keys)
model_buffers = {n for n, _ in cls.named_buffers()}
unexpected_keys = list(unexpected_keys - model_buffers)
ptrs = collections.defaultdict(list)
for name, tensor in cls.state_dict().items():
ptrs[(tensor.device, storage_ptr(tensor), storage_size(tensor))].append(
name
)
# retrieve unintialized modules and initialize before maybe overriding that with the pretrained weights.
_loaded_keys = loaded_keys
for module_name, module in cls.named_modules():
loaded_keys = [
k.replace(f"{module_name}.", "")
for k in _loaded_keys
if k.startswith(f"{module_name}.")
]
if len(set(module.state_dict().keys()) - set(loaded_keys)) == 0:
module._is_hf_initialized = True
# Make sure we are able to load base models as well as derived models (with heads)
start_prefix = "" start_prefix = ""
model_to_load = cls model_to_load = cls
error_msgs = [] error_msgs = []
@ -606,26 +537,17 @@ class ChatGLMForConditionalGeneration(nn.Module):
resolved_archive_file, desc="Loading checkpoint shards" resolved_archive_file, desc="Loading checkpoint shards"
) )
for shard_file in resolved_archive_file: for shard_file in resolved_archive_file:
# Skip the load for shards that only contain disk-offloaded weights when using safetensors for the offload.
state_dict = torch.load(shard_file, map_location="cpu") state_dict = torch.load(shard_file, map_location="cpu")
# Mistmatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
# matching the weights in the model.
error_msgs += cls._load_state_dict_into_model( error_msgs += cls._load_state_dict_into_model(
model_to_load, state_dict, start_prefix model_to_load, state_dict, start_prefix
) )
# force memory release del state_dict # force memory release
del state_dict
gc.collect() gc.collect()
print( print(
f"All model checkpoint weights were used when initializing {cls.__class__.__name__}.\n" f"All model checkpoint weights were used when initializing {cls.__class__.__name__}.\n"
) )
print(
f"All the weights of {cls.__class__.__name__} were initialized from the model checkpoint at"
f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint"
f" was trained on, you can already use {cls.__class__.__name__} for predictions without further"
" training."
)
return cls return cls
@torch.inference_mode() @torch.inference_mode()

View File

@ -37,7 +37,7 @@ init_kwargs["name_or_path"] = pretrained_model_name_or_path
tokenizer = ChatGLMTokenizer(*init_inputs, **init_kwargs) tokenizer = ChatGLMTokenizer(*init_inputs, **init_kwargs)
glm = glm.from_pretrained(pretrained_model_name_or_path, config=config).half().cuda() glm = glm.from_pretrained(pretrained_model_name_or_path).half().cuda()
glm = glm.eval() glm = glm.eval()
query = "colin" query = "colin"
response, history = glm.chat(tokenizer, query, history=[]) response, history = glm.chat(tokenizer, query, history=[])