Update chatglm train.
parent 9deb809a88
commit 8adae2130c
@@ -20,7 +20,6 @@ from transformers.configuration_utils import PretrainedConfig
 from transformers.generation import GenerationConfig
 
 from configuration_chatglm import ChatGLMConfig
-from tools import show
 
 
 class RotaryEmbedding(nn.Module):
@@ -67,7 +66,6 @@ class RMSNorm(torch.nn.Module):
         input_dtype = hidden_states.dtype
         variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
         hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
-        # show.DumpTensorToImage(self.weight, "RMSNorm_weight.png")
         return (self.weight * hidden_states).to(input_dtype)
 
 
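Note: the normalization touched by this hunk is standard RMSNorm. A minimal self-contained sketch of the same computation, with an illustrative hidden_size and eps (the real values come from ChatGLMConfig):

import torch
from torch import nn

class RMSNorm(nn.Module):
    # Root-mean-square norm: no mean subtraction and no bias; scale the input
    # by the reciprocal RMS over the last dimension, then by a learned gain.
    def __init__(self, hidden_size, eps=1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        # Accumulate the mean square in float32 for numerical stability.
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
        # Cast back to the original (possibly half-precision) dtype.
        return (self.weight * hidden_states).to(input_dtype)

x = torch.randn(2, 4, 8, dtype=torch.float16)
print(RMSNorm(hidden_size=8)(x).shape)  # torch.Size([2, 4, 8])
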
@@ -433,7 +431,6 @@ class ChatGLMModel(nn.Module):
         inputs_embeds = self.embedding(input_ids)
 
         rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
-        # show.DumpTensorToImage(rotary_pos_emb[:, :, 0], "rotary_pos_emb.png", scale=0.1)
 
         rotary_pos_emb = rotary_pos_emb[position_ids]
         rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()
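Note: the rotary cache here is built once for the full seq_length and then gathered per token. A rough sketch of that indexing with toy shapes (the base of 10000 and the cos/sin layout are assumptions; the real cache comes from RotaryEmbedding):

import torch

seq_length, rot_dim = 32, 8  # illustrative; the model uses much larger values
inv_freq = 1.0 / (10000 ** (torch.arange(0, rot_dim, 2).float() / rot_dim))
t = torch.arange(seq_length).float()
freqs = torch.outer(t, inv_freq)                    # [seq_length, rot_dim // 2]
rotary_pos_emb = torch.stack((freqs.cos(), freqs.sin()), dim=-1)

position_ids = torch.arange(6).unsqueeze(0)         # [batch=1, seq=6]
per_token = rotary_pos_emb[position_ids]            # [1, 6, rot_dim // 2, 2]
per_token = per_token.transpose(0, 1).contiguous()  # [seq, batch, ...] layout
print(per_token.shape)  # torch.Size([6, 1, 4, 2])
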
@@ -1,8 +1,12 @@
 import sys
 sys.path.append("..")
 
 import json
 import torch
 
-from chatglm import ChatGLMForConditionalGeneration
-from chatglm import ChatGLMTokenizer
+from modeling_chatglm import ChatGLMForConditionalGeneration
+from tokenization_chatglm import ChatGLMTokenizer
+from modelscope import snapshot_download
+
+from tools import show
+
@@ -12,8 +16,7 @@ seed = 4321
 torch.manual_seed(seed)
 torch.cuda.manual_seed_all(seed)
 
-pretrained_model_name_or_path = "../ZhipuAI/chatglm3-6b"
+pretrained_model_name_or_path = snapshot_download("ZhipuAI/chatglm3-6b")
 config, kwargs = AutoConfig.from_pretrained(
     pretrained_model_name_or_path,
     return_unused_kwargs=True,
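Note: snapshot_download replaces the hard-coded relative checkpoint path. A small usage sketch (the cache location shown is only an example):

from modelscope import snapshot_download

# Downloads ZhipuAI/chatglm3-6b into the local ModelScope cache on first use,
# then returns the cached directory, so the rest of the script can treat the
# result exactly like a local checkout.
model_dir = snapshot_download("ZhipuAI/chatglm3-6b")
print(model_dir)  # e.g. ~/.cache/modelscope/hub/ZhipuAI/chatglm3-6b
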
@@ -24,7 +27,7 @@ config, kwargs = AutoConfig.from_pretrained(
 glm = ChatGLMForConditionalGeneration(config)
 
 
-tokenizer_config_file = "./chatglm/tokenizer_config.json"
+tokenizer_config_file = "./tokenizer_config.json"
 if tokenizer_config_file is not None:
     with open(tokenizer_config_file, encoding="utf-8") as tokenizer_config_handle:
         init_kwargs = json.load(tokenizer_config_handle)
@@ -32,7 +35,7 @@ if tokenizer_config_file is not None:
     init_kwargs.pop("tokenizer_file", None)
     saved_init_inputs = init_kwargs.pop("init_inputs", ())
     init_inputs = saved_init_inputs
-    init_kwargs["vocab_file"] = "./chatglm/tokenizer.model"
+    init_kwargs["vocab_file"] = "./tokenizer.model"
     init_kwargs["added_tokens_file"] = None
     init_kwargs["special_tokens_map_file"] = None
     init_kwargs["tokenizer_file"] = None
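Note: this block hand-builds the tokenizer instead of calling AutoTokenizer.from_pretrained. A condensed sketch of the same bootstrap under the paths introduced by this commit (the exact set of keys popped, and ChatGLMTokenizer accepting the remaining entries as keyword arguments, are assumptions):

import json

from tokenization_chatglm import ChatGLMTokenizer

with open("./tokenizer_config.json", encoding="utf-8") as handle:
    init_kwargs = json.load(handle)

# Drop entries that name a class or a fast-tokenizer file rather than a
# constructor argument, then point vocab_file at the sentencepiece model.
init_kwargs.pop("tokenizer_class", None)  # assumption: present in the saved config
init_kwargs.pop("tokenizer_file", None)
init_inputs = init_kwargs.pop("init_inputs", ())
init_kwargs["vocab_file"] = "./tokenizer.model"

tokenizer = ChatGLMTokenizer(*init_inputs, **init_kwargs)
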
@@ -1 +1 @@
-import show
+from tools import show