Update code.
This commit is contained in:
parent
fa7078b72d
commit
0bc7bc90b1
51
Readme.md
51
Readme.md
|
@ -5,24 +5,24 @@
|
|||
|
||||
input_ids = tokenizer.build_chat_input(query, history=history, role=role)
|
||||
|
||||
for
|
||||
input_ids -> [1, 6] 1:batch_num 6:sequence_length
|
||||
inputs_embeds -> [6, 1, 4096] 4096:hidden_size
|
||||
rotary_pos_emb -> [6, 1, 32, 2] 32:pos的编码维度 2:cos+sin
|
||||
for:
|
||||
input_ids -> [1, 6] 1:batch_num 6:sequence_length
|
||||
inputs_embeds -> [6, 1, 4096] 4096:hidden_size
|
||||
rotary_pos_emb -> [6, 1, 32, 2] 32:pos的编码维度 2:cos+sin
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
for layers : GLMBlock(hidden_states, rotary_pos_emb)
|
||||
hidden_states = RMSNorm(hidden_states)
|
||||
hidden_states = hidden_states[-1:] 截取最后一个sequence
|
||||
lm_logits = self.output_layer(hidden_states)
|
||||
lm_logits = lm_logits.transpose(0, 1).contiguous() -> [1, 1, 65024]
|
||||
hidden_states = inputs_embeds
|
||||
for layers : GLMBlock(hidden_states, rotary_pos_emb)
|
||||
hidden_states = RMSNorm(hidden_states)
|
||||
hidden_states = hidden_states[-1:] 截取最后一个sequence
|
||||
lm_logits = self.output_layer(hidden_states)
|
||||
lm_logits = lm_logits.transpose(0, 1).contiguous() -> [1, 1, 65024]
|
||||
|
||||
probs = softmax(lm_logits) -> [1, 65024]
|
||||
next_tokens = torch.multinomial(probs, num_samples=1) 采样 -> [1] 1:batch_num
|
||||
probs = softmax(lm_logits) -> [1, 65024]
|
||||
next_tokens = torch.multinomial(probs, num_samples=1) 采样 -> [1] 1:batch_num
|
||||
|
||||
if next_tokens == eos_token_id 推理结束退出循环
|
||||
if next_tokens == eos_token_id 推理结束退出循环
|
||||
|
||||
input_ids = torch.cat([input_ids, next_tokens) -> [1, 7] 1:batch_num
|
||||
input_ids = torch.cat([input_ids, next_tokens) -> [1, 7] 1:batch_num
|
||||
|
||||
response = tokenizer.decode(outputs)
|
||||
|
||||
|
@ -40,19 +40,6 @@ Linear(hidden_states) no bias -> [6, 1, 27392]
|
|||
silu (x) = [6, 1, 13696] * sigmoid([6, 1, 13696])
|
||||
Linear(intermediate_parallel) no bias -> [6, 1, 4096]
|
||||
|
||||
## core_attention
|
||||
|
||||
query_layer=query_layer.permute(1, 2, 0, 3) -> [1, 32, 6, 128]
|
||||
key_layer=key_layer.permute(1, 2, 0, 3) -> [1, 32, 6, 128]
|
||||
value_layer=value_layer.permute(1, 2, 0, 3) -> [1, 32, 6, 128]
|
||||
context_layer = scaled_dot_product_attention(query_layer, key_layer, value_layer) -> [1, 32, 6, 128]
|
||||
softmax(QK^T/sqrt(in_dim))V
|
||||
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
|
||||
att = F.softmax(att, dim=-1)
|
||||
y = att @ v -> (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
|
||||
context_layer = context_layer.permute(2, 0, 1, 3)
|
||||
context_layer = context_layer.reshape() -> [6, 1, 4096]
|
||||
|
||||
## self_attention
|
||||
|
||||
hidden_states: [s, b, h]
|
||||
|
@ -74,7 +61,15 @@ value_layer = value_layer.unsqueeze(-2) -> [6, 1, 2, 1, 128]
|
|||
value_layer = value_layer.expand -> [6, 1, 2, 16, 128]
|
||||
value_layer = value_layer.contiguous().view -> [6, 1, 32, 128]
|
||||
|
||||
context_layer = self.core_attention(query_layer, key_layer, value_layer) -> [6, 1, 4096]
|
||||
query_layer permute(1, 2, 0, 3) -> [1, 32, 6, 128]
|
||||
key_layer permute(1, 2, 0, 3) -> [1, 32, 6, 128]
|
||||
value_layer permute(1, 2, 0, 3) -> [1, 32, 6, 128]
|
||||
context_layer = scaled_dot_product_attention(query_layer, key_layer, value_layer) -> [1, 32, 6, 128]
|
||||
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
|
||||
att = F.softmax(att, dim=-1)
|
||||
y = att @ v -> (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
|
||||
context_layer = context_layer.permute(2, 0, 1, 3).reshape() -> [6, 1, 4096]
|
||||
|
||||
return Linear(context_layer) -> [6, 1, 4096]
|
||||
|
||||
## GLMBlock
|
||||
|
|
|
@ -22,8 +22,6 @@ from transformers.generation import GenerationConfig
|
|||
from chatglm import ChatGLMConfig
|
||||
from tools import show
|
||||
|
||||
WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"
|
||||
|
||||
|
||||
class RotaryEmbedding(nn.Module):
|
||||
def __init__(self, dim: int, original_impl=False, device=None, dtype=None):
|
||||
|
@ -487,60 +485,26 @@ class ChatGLMForConditionalGeneration(nn.Module):
|
|||
self.generation_config = GenerationConfig.from_model_config(config)
|
||||
|
||||
def from_pretrained(
|
||||
cls,
|
||||
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
|
||||
config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
|
||||
**kwargs,
|
||||
cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]]
|
||||
):
|
||||
state_dict = kwargs.pop("state_dict", None)
|
||||
_ = kwargs.pop("mirror", None)
|
||||
load_in_8bit = kwargs.pop("load_in_8bit", False)
|
||||
load_in_4bit = kwargs.pop("load_in_4bit", False)
|
||||
subfolder = kwargs.pop("subfolder", "")
|
||||
load_in_8bit = False
|
||||
load_in_4bit = False
|
||||
|
||||
# This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
|
||||
# index of the files.
|
||||
sharded_metadata = None
|
||||
|
||||
# if pretrained_model_name_or_path is not None:
|
||||
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
||||
archive_file = os.path.join(
|
||||
pretrained_model_name_or_path,
|
||||
subfolder,
|
||||
WEIGHTS_INDEX_NAME,
|
||||
resolved_archive_file = os.path.join(
|
||||
pretrained_model_name_or_path, "pytorch_model.bin.index.json"
|
||||
)
|
||||
print(f"loading weights file {archive_file}")
|
||||
resolved_archive_file = archive_file
|
||||
|
||||
print(f"loading weights file {resolved_archive_file}")
|
||||
with open(resolved_archive_file, "r") as f:
|
||||
index = json.loads(f.read())
|
||||
|
||||
shard_filenames = sorted(set(index["weight_map"].values()))
|
||||
sharded_metadata = index["metadata"]
|
||||
sharded_metadata["all_checkpoint_keys"] = list(index["weight_map"].keys())
|
||||
sharded_metadata["weight_map"] = index["weight_map"].copy()
|
||||
|
||||
resolved_archive_file = [
|
||||
os.path.join(pretrained_model_name_or_path, subfolder, f)
|
||||
for f in shard_filenames
|
||||
os.path.join(pretrained_model_name_or_path, f) for f in shard_filenames
|
||||
]
|
||||
|
||||
loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
|
||||
config.name_or_path = pretrained_model_name_or_path
|
||||
config = copy.deepcopy(config)
|
||||
# make sure we use the model's config since the __init__ call might have copied it
|
||||
config = cls.config
|
||||
# restore default dtype
|
||||
model = cls._load_pretrained_model(
|
||||
state_dict,
|
||||
loaded_state_dict_keys,
|
||||
resolved_archive_file,
|
||||
pretrained_model_name_or_path,
|
||||
)
|
||||
model = cls._load_pretrained_model(resolved_archive_file)
|
||||
model.is_loaded_in_4bit = load_in_4bit
|
||||
model.is_loaded_in_8bit = load_in_8bit
|
||||
# Set model in evaluation mode to deactivate DropOut modules by default
|
||||
model.eval()
|
||||
model.eval() # Set model in evaluation mode to deactivate DropOut modules by default
|
||||
return model
|
||||
|
||||
def _load_state_dict_into_model(self, model_to_load, state_dict, start_prefix):
|
||||
|
@ -564,40 +528,7 @@ class ChatGLMForConditionalGeneration(nn.Module):
|
|||
del state_dict
|
||||
return error_msgs
|
||||
|
||||
def _load_pretrained_model(
|
||||
cls,
|
||||
state_dict,
|
||||
loaded_keys,
|
||||
resolved_archive_file,
|
||||
pretrained_model_name_or_path,
|
||||
):
|
||||
# Retrieve missing & unexpected_keys
|
||||
model_state_dict = cls.state_dict()
|
||||
expected_keys = list(model_state_dict.keys())
|
||||
loaded_keys = [key for key in loaded_keys]
|
||||
|
||||
unexpected_keys = set(loaded_keys) - set(expected_keys)
|
||||
model_buffers = {n for n, _ in cls.named_buffers()}
|
||||
unexpected_keys = list(unexpected_keys - model_buffers)
|
||||
|
||||
ptrs = collections.defaultdict(list)
|
||||
for name, tensor in cls.state_dict().items():
|
||||
ptrs[(tensor.device, storage_ptr(tensor), storage_size(tensor))].append(
|
||||
name
|
||||
)
|
||||
# retrieve unintialized modules and initialize before maybe overriding that with the pretrained weights.
|
||||
|
||||
_loaded_keys = loaded_keys
|
||||
for module_name, module in cls.named_modules():
|
||||
loaded_keys = [
|
||||
k.replace(f"{module_name}.", "")
|
||||
for k in _loaded_keys
|
||||
if k.startswith(f"{module_name}.")
|
||||
]
|
||||
if len(set(module.state_dict().keys()) - set(loaded_keys)) == 0:
|
||||
module._is_hf_initialized = True
|
||||
|
||||
# Make sure we are able to load base models as well as derived models (with heads)
|
||||
def _load_pretrained_model(cls, resolved_archive_file):
|
||||
start_prefix = ""
|
||||
model_to_load = cls
|
||||
error_msgs = []
|
||||
|
@ -606,26 +537,17 @@ class ChatGLMForConditionalGeneration(nn.Module):
|
|||
resolved_archive_file, desc="Loading checkpoint shards"
|
||||
)
|
||||
for shard_file in resolved_archive_file:
|
||||
# Skip the load for shards that only contain disk-offloaded weights when using safetensors for the offload.
|
||||
state_dict = torch.load(shard_file, map_location="cpu")
|
||||
# Mistmatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
|
||||
# matching the weights in the model.
|
||||
|
||||
error_msgs += cls._load_state_dict_into_model(
|
||||
model_to_load, state_dict, start_prefix
|
||||
)
|
||||
# force memory release
|
||||
del state_dict
|
||||
del state_dict # force memory release
|
||||
gc.collect()
|
||||
|
||||
print(
|
||||
f"All model checkpoint weights were used when initializing {cls.__class__.__name__}.\n"
|
||||
)
|
||||
print(
|
||||
f"All the weights of {cls.__class__.__name__} were initialized from the model checkpoint at"
|
||||
f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint"
|
||||
f" was trained on, you can already use {cls.__class__.__name__} for predictions without further"
|
||||
" training."
|
||||
)
|
||||
return cls
|
||||
|
||||
@torch.inference_mode()
|
||||
|
|
2
demo.py
2
demo.py
|
@ -37,7 +37,7 @@ init_kwargs["name_or_path"] = pretrained_model_name_or_path
|
|||
tokenizer = ChatGLMTokenizer(*init_inputs, **init_kwargs)
|
||||
|
||||
|
||||
glm = glm.from_pretrained(pretrained_model_name_or_path, config=config).half().cuda()
|
||||
glm = glm.from_pretrained(pretrained_model_name_or_path).half().cuda()
|
||||
glm = glm.eval()
|
||||
query = "colin"
|
||||
response, history = glm.chat(tokenizer, query, history=[])
|
||||
|
|
Loading…
Reference in New Issue