From 8adae2130c0ca29018a2b996a57e265645a9688d Mon Sep 17 00:00:00 2001 From: Colin Date: Thu, 4 Jan 2024 19:56:30 +0800 Subject: [PATCH] Update chatglm train. --- chatglm/modeling_chatglm.py | 3 --- train.py => chatglm/train.py | 15 +++++++++------ tools/__init__.py | 2 +- 3 files changed, 10 insertions(+), 10 deletions(-) rename train.py => chatglm/train.py (76%) diff --git a/chatglm/modeling_chatglm.py b/chatglm/modeling_chatglm.py index 74cb5db..84eac7e 100644 --- a/chatglm/modeling_chatglm.py +++ b/chatglm/modeling_chatglm.py @@ -20,7 +20,6 @@ from transformers.configuration_utils import PretrainedConfig from transformers.generation import GenerationConfig from configuration_chatglm import ChatGLMConfig -from tools import show class RotaryEmbedding(nn.Module): @@ -67,7 +66,6 @@ class RMSNorm(torch.nn.Module): input_dtype = hidden_states.dtype variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) hidden_states = hidden_states * torch.rsqrt(variance + self.eps) - # show.DumpTensorToImage(self.weight, "RMSNorm_weight.png") return (self.weight * hidden_states).to(input_dtype) @@ -433,7 +431,6 @@ class ChatGLMModel(nn.Module): inputs_embeds = self.embedding(input_ids) rotary_pos_emb = self.rotary_pos_emb(self.seq_length) - # show.DumpTensorToImage(rotary_pos_emb[:, :, 0], "rotary_pos_emb.png", scale=0.1) rotary_pos_emb = rotary_pos_emb[position_ids] rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous() diff --git a/train.py b/chatglm/train.py similarity index 76% rename from train.py rename to chatglm/train.py index e543bce..8c7ba84 100644 --- a/train.py +++ b/chatglm/train.py @@ -1,8 +1,12 @@ +import sys +sys.path.append("..") + import json import torch -from chatglm import ChatGLMForConditionalGeneration -from chatglm import ChatGLMTokenizer +from modeling_chatglm import ChatGLMForConditionalGeneration +from tokenization_chatglm import ChatGLMTokenizer +from modelscope import snapshot_download from tools import show @@ -12,8 +16,7 @@ seed = 4321 torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) - -pretrained_model_name_or_path = "../ZhipuAI/chatglm3-6b" +pretrained_model_name_or_path = snapshot_download("ZhipuAI/chatglm3-6b") config, kwargs = AutoConfig.from_pretrained( pretrained_model_name_or_path, return_unused_kwargs=True, @@ -24,7 +27,7 @@ config, kwargs = AutoConfig.from_pretrained( glm = ChatGLMForConditionalGeneration(config) -tokenizer_config_file = "./chatglm/tokenizer_config.json" +tokenizer_config_file = "./tokenizer_config.json" if tokenizer_config_file is not None: with open(tokenizer_config_file, encoding="utf-8") as tokenizer_config_handle: init_kwargs = json.load(tokenizer_config_handle) @@ -32,7 +35,7 @@ if tokenizer_config_file is not None: init_kwargs.pop("tokenizer_file", None) saved_init_inputs = init_kwargs.pop("init_inputs", ()) init_inputs = saved_init_inputs -init_kwargs["vocab_file"] = "./chatglm/tokenizer.model" +init_kwargs["vocab_file"] = "./tokenizer.model" init_kwargs["added_tokens_file"] = None init_kwargs["special_tokens_map_file"] = None init_kwargs["tokenizer_file"] = None diff --git a/tools/__init__.py b/tools/__init__.py index 65d088a..d5e2687 100644 --- a/tools/__init__.py +++ b/tools/__init__.py @@ -1 +1 @@ -import show \ No newline at end of file +from tools import show \ No newline at end of file