Update more dump.
|
@ -558,9 +558,10 @@ class ChatGLMForConditionalGeneration(nn.Module):
|
|||
|
||||
outputs = self.sample(
|
||||
input_ids,
|
||||
pad_token_id=generation_config.pad_token_id,
|
||||
eos_token_id=generation_config.eos_token_id,
|
||||
output_hidden_states=generation_config.output_hidden_states,
|
||||
generation_config.pad_token_id,
|
||||
generation_config.eos_token_id,
|
||||
generation_config.output_hidden_states,
|
||||
tokenizer,
|
||||
)
|
||||
|
||||
outputs = outputs.tolist()[0][len(inputs["input_ids"][0]) : -1]
|
||||
|
@ -574,6 +575,7 @@ class ChatGLMForConditionalGeneration(nn.Module):
|
|||
pad_token_id: Optional[int] = None,
|
||||
eos_token_id: Optional[Union[int, List[int]]] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
tokenizer=None,
|
||||
):
|
||||
if isinstance(eos_token_id, int):
|
||||
eos_token_id = [eos_token_id]
|
||||
|
@ -582,7 +584,7 @@ class ChatGLMForConditionalGeneration(nn.Module):
|
|||
isFinished = torch.zeros(
|
||||
input_ids.shape[0], dtype=torch.long, device=input_ids.device
|
||||
)
|
||||
token_count = 0
|
||||
# token_count = 0
|
||||
while True:
|
||||
input_ids_in = input_ids
|
||||
batch_size, seq_length = input_ids_in.shape
|
||||
|
@ -599,11 +601,13 @@ class ChatGLMForConditionalGeneration(nn.Module):
|
|||
)
|
||||
next_token_logits = logits[:, -1, :]
|
||||
probs = nn.functional.softmax(next_token_logits, dim=-1)
|
||||
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
|
||||
|
||||
# show.DumpTensorToImage(next_token_logits[0], "generated/next_tokens"+str(token_count)+".png")
|
||||
# response = tokenizer.decode(next_tokens)
|
||||
# name = "generated/next_tokens" + str(token_count) + "_" + response + "_.png"
|
||||
# show.DumpTensorToImage(next_token_logits[0], name)
|
||||
# token_count = token_count + 1
|
||||
|
||||
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
|
||||
# finished sentences should add a padding token to next
|
||||
pad_token = pad_token_id * isFinished
|
||||
next_tokens = next_tokens * (1 - isFinished) + pad_token
|
||||
|
|
4
demo.py
|
@ -8,7 +8,7 @@ from tools import show
|
|||
|
||||
from transformers import AutoConfig
|
||||
|
||||
seed = 1234
|
||||
seed = 4321
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
|
@ -45,7 +45,7 @@ glm = glm.eval()
|
|||
query = "你好"
|
||||
response, history = glm.chat(tokenizer, query, history=[])
|
||||
print(response)
|
||||
if response[1:] != " 你好👋!我是人工智能助手 ChatGLM3-6B,很高兴见到你,欢迎问我任何问题":
|
||||
if response[1:] != " 你好!有什么可以帮助您的吗":
|
||||
raise ()
|
||||
|
||||
# query = "colin"
|
||||
|
|
Before Width: | Height: | Size: 50 KiB After Width: | Height: | Size: 50 KiB |
Before Width: | Height: | Size: 52 KiB |
After Width: | Height: | Size: 52 KiB |
Before Width: | Height: | Size: 52 KiB |
Before Width: | Height: | Size: 52 KiB |
Before Width: | Height: | Size: 49 KiB |
Before Width: | Height: | Size: 48 KiB |
Before Width: | Height: | Size: 50 KiB |
Before Width: | Height: | Size: 47 KiB |
Before Width: | Height: | Size: 47 KiB |
Before Width: | Height: | Size: 48 KiB |
Before Width: | Height: | Size: 48 KiB |
Before Width: | Height: | Size: 46 KiB After Width: | Height: | Size: 46 KiB |
Before Width: | Height: | Size: 51 KiB |
Before Width: | Height: | Size: 51 KiB |
Before Width: | Height: | Size: 51 KiB |
Before Width: | Height: | Size: 50 KiB |
Before Width: | Height: | Size: 52 KiB |
Before Width: | Height: | Size: 50 KiB |
Before Width: | Height: | Size: 50 KiB |
Before Width: | Height: | Size: 50 KiB |
Before Width: | Height: | Size: 50 KiB |
Before Width: | Height: | Size: 50 KiB |
Before Width: | Height: | Size: 53 KiB After Width: | Height: | Size: 53 KiB |
Before Width: | Height: | Size: 50 KiB After Width: | Height: | Size: 50 KiB |
Before Width: | Height: | Size: 52 KiB After Width: | Height: | Size: 52 KiB |
Before Width: | Height: | Size: 51 KiB |
After Width: | Height: | Size: 53 KiB |
Before Width: | Height: | Size: 51 KiB |
After Width: | Height: | Size: 52 KiB |
Before Width: | Height: | Size: 51 KiB |
After Width: | Height: | Size: 51 KiB |
Before Width: | Height: | Size: 50 KiB |
After Width: | Height: | Size: 51 KiB |
Before Width: | Height: | Size: 52 KiB |
After Width: | Height: | Size: 49 KiB |
|
@ -0,0 +1,37 @@
|
|||
import json
|
||||
import torch
|
||||
from tools import show
|
||||
|
||||
from chatglm import ChatGLMTokenizer
|
||||
|
||||
pretrained_model_name_or_path = "../ZhipuAI/chatglm3-6b"
|
||||
|
||||
|
||||
tokenizer_config_file = "./chatglm/tokenizer_config.json"
|
||||
if tokenizer_config_file is not None:
|
||||
with open(tokenizer_config_file, encoding="utf-8") as tokenizer_config_handle:
|
||||
init_kwargs = json.load(tokenizer_config_handle)
|
||||
init_kwargs.pop("tokenizer_class", None)
|
||||
init_kwargs.pop("tokenizer_file", None)
|
||||
saved_init_inputs = init_kwargs.pop("init_inputs", ())
|
||||
init_inputs = saved_init_inputs
|
||||
init_kwargs["vocab_file"] = "./chatglm/tokenizer.model"
|
||||
init_kwargs["added_tokens_file"] = None
|
||||
init_kwargs["special_tokens_map_file"] = None
|
||||
init_kwargs["tokenizer_file"] = None
|
||||
init_kwargs["name_or_path"] = pretrained_model_name_or_path
|
||||
tokenizer = ChatGLMTokenizer(*init_inputs, **init_kwargs)
|
||||
|
||||
|
||||
aa = tokenizer.build_chat_input("骉")
|
||||
ab = tokenizer.encode("骉")
|
||||
a = tokenizer.decode([236,173,140])
|
||||
|
||||
|
||||
|
||||
token = []
|
||||
for i in range(64798):
|
||||
token.append(str(i) + " : " + tokenizer.decode(i))
|
||||
show.DumpListToFile(token, "generated/token.log")
|
||||
|
||||
|
|
@ -41,8 +41,17 @@ def DumpTensorToLog(tensor, name="log"):
|
|||
f.writelines("%s" % d + os.linesep)
|
||||
f.close()
|
||||
|
||||
|
||||
def DumpTensorToFile(tensor, name="tensor.pt"):
|
||||
torch.save(tensor.cpu(), name)
|
||||
|
||||
|
||||
def LoadTensorToFile(name="tensor.pt"):
|
||||
return torch.load(name)
|
||||
|
||||
|
||||
def DumpListToFile(list, name="list"):
|
||||
f = open(name, "w")
|
||||
for d in list:
|
||||
f.writelines("%s" % d + os.linesep)
|
||||
f.close()
|
||||
|
|