Update code.

parent fa7078b72d
commit 0bc7bc90b1

Readme.md (25 changed lines)
@@ -5,7 +5,7 @@

 input_ids = tokenizer.build_chat_input(query, history=history, role=role)

-for
+for:
 input_ids -> [1, 6]  1: batch_num  6: sequence_length
 inputs_embeds -> [6, 1, 4096]  4096: hidden_size
 rotary_pos_emb -> [6, 1, 32, 2]  32: positional-encoding dim  2: cos+sin
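A minimal, self-contained sketch of how those three shapes can be produced with dummy tensors; vocab_size=65024 and rotary_dim=64 (half of head_dim=128) are assumptions about the model config, not values taken from this diff:

```python
import torch

# Assumed sizes: batch=1, seq_len=6, hidden_size=4096, rotary_dim=head_dim//2=64, vocab_size=65024.
batch, seq_len, hidden_size, rotary_dim, vocab_size = 1, 6, 4096, 64, 65024

input_ids = torch.randint(0, vocab_size, (batch, seq_len))                                # [1, 6]
inputs_embeds = torch.nn.Embedding(vocab_size, hidden_size)(input_ids).transpose(0, 1)    # [6, 1, 4096] (seq-first)

# Rotary cache: one (cos, sin) pair per position and per frequency pair -> 32 = rotary_dim // 2
inv_freq = 1.0 / (10000 ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim))          # [32]
angles = torch.outer(torch.arange(seq_len).float(), inv_freq)                              # [6, 32]
rotary_pos_emb = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1)               # [6, 32, 2]
rotary_pos_emb = rotary_pos_emb.unsqueeze(1)                                               # [6, 1, 32, 2]

print(input_ids.shape, inputs_embeds.shape, rotary_pos_emb.shape)
```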
@@ -40,19 +40,6 @@ Linear(hidden_states) no bias -> [6, 1, 27392]
 silu(x) = [6, 1, 13696] * sigmoid([6, 1, 13696])
 Linear(intermediate_parallel) no bias -> [6, 1, 4096]

-## core_attention
-
-query_layer = query_layer.permute(1, 2, 0, 3) -> [1, 32, 6, 128]
-key_layer = key_layer.permute(1, 2, 0, 3) -> [1, 32, 6, 128]
-value_layer = value_layer.permute(1, 2, 0, 3) -> [1, 32, 6, 128]
-context_layer = scaled_dot_product_attention(query_layer, key_layer, value_layer) -> [1, 32, 6, 128]
-softmax(QK^T / sqrt(in_dim)) V
-att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
-att = F.softmax(att, dim=-1)
-y = att @ v -> (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
-context_layer = context_layer.permute(2, 0, 1, 3)
-context_layer = context_layer.reshape() -> [6, 1, 4096]
-
 ## self_attention

 hidden_states: [s, b, h]
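The Linear / silu / Linear lines above describe the gated MLP. A minimal sketch of that block, assuming the 4096 / 13696 / 27392 sizes noted above; the layer names are illustrative, not necessarily the repo's:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class GatedSiluMLP(nn.Module):
    def __init__(self, hidden_size=4096, ffn_hidden_size=13696):
        super().__init__()
        # one projection produces both the gate and the value half: 2 * 13696 = 27392
        self.dense_h_to_4h = nn.Linear(hidden_size, 2 * ffn_hidden_size, bias=False)
        self.dense_4h_to_h = nn.Linear(ffn_hidden_size, hidden_size, bias=False)

    def forward(self, hidden_states):                 # [s, b, 4096]
        x = self.dense_h_to_4h(hidden_states)          # [s, b, 27392]
        gate, value = x.chunk(2, dim=-1)               # two [s, b, 13696] halves
        intermediate = F.silu(gate) * value            # silu(x) = x * sigmoid(x), applied to the gate half
        return self.dense_4h_to_h(intermediate)        # [s, b, 4096]

out = GatedSiluMLP()(torch.randn(6, 1, 4096))
print(out.shape)  # torch.Size([6, 1, 4096])
```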
@@ -74,7 +61,15 @@ value_layer = value_layer.unsqueeze(-2) -> [6, 1, 2, 1, 128]
 value_layer = value_layer.expand -> [6, 1, 2, 16, 128]
 value_layer = value_layer.contiguous().view -> [6, 1, 32, 128]

-context_layer = self.core_attention(query_layer, key_layer, value_layer) -> [6, 1, 4096]
+query_layer permute(1, 2, 0, 3) -> [1, 32, 6, 128]
+key_layer permute(1, 2, 0, 3) -> [1, 32, 6, 128]
+value_layer permute(1, 2, 0, 3) -> [1, 32, 6, 128]
+context_layer = scaled_dot_product_attention(query_layer, key_layer, value_layer) -> [1, 32, 6, 128]
+att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
+att = F.softmax(att, dim=-1)
+y = att @ v -> (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
+context_layer = context_layer.permute(2, 0, 1, 3).reshape() -> [6, 1, 4096]

 return Linear(context_layer) -> [6, 1, 4096]

 ## GLMBlock

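The self_attention notes above inline the permute + scaled_dot_product_attention steps that used to live in the separate core_attention section. A runnable sketch of that path on dummy tensors, assuming 32 query heads, 2 key/value heads, and head_dim 128 as in the shapes above:

```python
import math
import torch
import torch.nn.functional as F

s, b, nh, nkv, hd = 6, 1, 32, 2, 128
query_layer = torch.randn(s, b, nh, hd)
key_layer   = torch.randn(s, b, nkv, hd)
value_layer = torch.randn(s, b, nkv, hd)

# grouped-query expansion: repeat each of the 2 KV heads for 16 query heads
key_layer   = key_layer.unsqueeze(-2).expand(s, b, nkv, nh // nkv, hd).contiguous().view(s, b, nh, hd)
value_layer = value_layer.unsqueeze(-2).expand(s, b, nkv, nh // nkv, hd).contiguous().view(s, b, nh, hd)

# [s, b, nh, hd] -> [b, nh, s, hd] for attention
q = query_layer.permute(1, 2, 0, 3)
k = key_layer.permute(1, 2, 0, 3)
v = value_layer.permute(1, 2, 0, 3)

# the fused call ...
context = F.scaled_dot_product_attention(q, k, v)                   # [1, 32, 6, 128]

# ... matches the manual form written out in the README
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))     # [1, 32, 6, 6]
att = F.softmax(att, dim=-1)
y = att @ v                                                          # [1, 32, 6, 128]
print(torch.allclose(context, y, atol=1e-5))

# back to [s, b, hidden_size] = [6, 1, 4096]
context_layer = context.permute(2, 0, 1, 3).reshape(s, b, nh * hd)
print(context_layer.shape)
```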
@@ -22,8 +22,6 @@ from transformers.generation import GenerationConfig
 from chatglm import ChatGLMConfig
 from tools import show

-WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"
-

 class RotaryEmbedding(nn.Module):
     def __init__(self, dim: int, original_impl=False, device=None, dtype=None):
@@ -487,60 +485,26 @@ class ChatGLMForConditionalGeneration(nn.Module):
         self.generation_config = GenerationConfig.from_model_config(config)

     def from_pretrained(
-        cls,
-        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
-        config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
-        **kwargs,
+        cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]]
     ):
-        state_dict = kwargs.pop("state_dict", None)
-        _ = kwargs.pop("mirror", None)
-        load_in_8bit = kwargs.pop("load_in_8bit", False)
-        load_in_4bit = kwargs.pop("load_in_4bit", False)
-        subfolder = kwargs.pop("subfolder", "")
-
-        # This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
-        # index of the files.
-        sharded_metadata = None
-
-        # if pretrained_model_name_or_path is not None:
+        load_in_8bit = False
+        load_in_4bit = False
         pretrained_model_name_or_path = str(pretrained_model_name_or_path)
-        archive_file = os.path.join(
-            pretrained_model_name_or_path,
-            subfolder,
-            WEIGHTS_INDEX_NAME,
+        resolved_archive_file = os.path.join(
+            pretrained_model_name_or_path, "pytorch_model.bin.index.json"
         )
-        print(f"loading weights file {archive_file}")
-        resolved_archive_file = archive_file
+        print(f"loading weights file {resolved_archive_file}")

         with open(resolved_archive_file, "r") as f:
             index = json.loads(f.read())

         shard_filenames = sorted(set(index["weight_map"].values()))
-        sharded_metadata = index["metadata"]
-        sharded_metadata["all_checkpoint_keys"] = list(index["weight_map"].keys())
-        sharded_metadata["weight_map"] = index["weight_map"].copy()

         resolved_archive_file = [
-            os.path.join(pretrained_model_name_or_path, subfolder, f)
-            for f in shard_filenames
+            os.path.join(pretrained_model_name_or_path, f) for f in shard_filenames
         ]
-        loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
-        config.name_or_path = pretrained_model_name_or_path
-        config = copy.deepcopy(config)
-        # make sure we use the model's config since the __init__ call might have copied it
-        config = cls.config
-        # restore default dtype
-        model = cls._load_pretrained_model(
-            state_dict,
-            loaded_state_dict_keys,
-            resolved_archive_file,
-            pretrained_model_name_or_path,
-        )
+        model = cls._load_pretrained_model(resolved_archive_file)
         model.is_loaded_in_4bit = load_in_4bit
         model.is_loaded_in_8bit = load_in_8bit
-        # Set model in evaluation mode to deactivate DropOut modules by default
-        model.eval()
+        model.eval()  # Set model in evaluation mode to deactivate DropOut modules by default
         return model

     def _load_state_dict_into_model(self, model_to_load, state_dict, start_prefix):
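For reference, the file read above is the standard sharded-checkpoint index: "weight_map" maps each parameter name to the shard file that contains it, and "metadata" carries totals. A made-up example of its structure (the keys, shard names, and size are illustrative, not from the actual checkpoint):

```python
import json

index_example = {
    "metadata": {"total_size": 12487168064},
    "weight_map": {
        "transformer.embedding.word_embeddings.weight": "pytorch_model-00001-of-00007.bin",
        "transformer.encoder.layers.0.self_attention.query_key_value.weight": "pytorch_model-00001-of-00007.bin",
        "transformer.output_layer.weight": "pytorch_model-00007-of-00007.bin",
    },
}
# sorted(set(...)) over the values yields the unique shard files to load, in order
shard_filenames = sorted(set(index_example["weight_map"].values()))
print(shard_filenames)
print(json.dumps(index_example["metadata"]))
```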
@@ -564,40 +528,7 @@ class ChatGLMForConditionalGeneration(nn.Module):
         del state_dict
         return error_msgs

-    def _load_pretrained_model(
-        cls,
-        state_dict,
-        loaded_keys,
-        resolved_archive_file,
-        pretrained_model_name_or_path,
-    ):
-        # Retrieve missing & unexpected_keys
-        model_state_dict = cls.state_dict()
-        expected_keys = list(model_state_dict.keys())
-        loaded_keys = [key for key in loaded_keys]
-
-        unexpected_keys = set(loaded_keys) - set(expected_keys)
-        model_buffers = {n for n, _ in cls.named_buffers()}
-        unexpected_keys = list(unexpected_keys - model_buffers)
-
-        ptrs = collections.defaultdict(list)
-        for name, tensor in cls.state_dict().items():
-            ptrs[(tensor.device, storage_ptr(tensor), storage_size(tensor))].append(
-                name
-            )
-        # retrieve unintialized modules and initialize before maybe overriding that with the pretrained weights.
-
-        _loaded_keys = loaded_keys
-        for module_name, module in cls.named_modules():
-            loaded_keys = [
-                k.replace(f"{module_name}.", "")
-                for k in _loaded_keys
-                if k.startswith(f"{module_name}.")
-            ]
-            if len(set(module.state_dict().keys()) - set(loaded_keys)) == 0:
-                module._is_hf_initialized = True
-
-        # Make sure we are able to load base models as well as derived models (with heads)
+    def _load_pretrained_model(cls, resolved_archive_file):
         start_prefix = ""
         model_to_load = cls
         error_msgs = []
@@ -606,26 +537,17 @@ class ChatGLMForConditionalGeneration(nn.Module):
             resolved_archive_file, desc="Loading checkpoint shards"
         )
         for shard_file in resolved_archive_file:
-            # Skip the load for shards that only contain disk-offloaded weights when using safetensors for the offload.
             state_dict = torch.load(shard_file, map_location="cpu")
-            # Mistmatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
-            # matching the weights in the model.
             error_msgs += cls._load_state_dict_into_model(
                 model_to_load, state_dict, start_prefix
             )
-            # force memory release
-            del state_dict
+            del state_dict  # force memory release
             gc.collect()

         print(
             f"All model checkpoint weights were used when initializing {cls.__class__.__name__}.\n"
         )
-        print(
-            f"All the weights of {cls.__class__.__name__} were initialized from the model checkpoint at"
-            f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint"
-            f" was trained on, you can already use {cls.__class__.__name__} for predictions without further"
-            " training."
-        )
         return cls

     @torch.inference_mode()

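The body of _load_state_dict_into_model is not shown in this diff. As a rough stand-in for what the shard loop relies on, a hypothetical helper built on plain load_state_dict could look like the sketch below; this is an assumption for illustration, not the repo's implementation:

```python
import torch

def load_shard_into_model(model: torch.nn.Module, shard_file: str, start_prefix: str = ""):
    # Hypothetical helper: load one checkpoint shard into an already-constructed model.
    state_dict = torch.load(shard_file, map_location="cpu")
    if start_prefix:
        # strip an optional prefix so base-model checkpoints load into derived models
        state_dict = {k[len(start_prefix):]: v for k, v in state_dict.items() if k.startswith(start_prefix)}
    # strict=False because a single shard only covers part of the model's parameters
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    return [f"unexpected key: {k}" for k in unexpected]
```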
demo.py (2 changed lines)

@@ -37,7 +37,7 @@ init_kwargs["name_or_path"] = pretrained_model_name_or_path
 tokenizer = ChatGLMTokenizer(*init_inputs, **init_kwargs)


-glm = glm.from_pretrained(pretrained_model_name_or_path, config=config).half().cuda()
+glm = glm.from_pretrained(pretrained_model_name_or_path).half().cuda()
 glm = glm.eval()
 query = "colin"
 response, history = glm.chat(tokenizer, query, history=[])