import json
import os
import torch
from typing import List, Optional, Union, Dict

from sentencepiece import SentencePieceProcessor
from transformers import PreTrainedTokenizer


class SPTokenizer:
    def __init__(self, model_path: str):
        # reload tokenizer
        assert os.path.isfile(model_path), model_path
        self.sp_model = SentencePieceProcessor(model_file=model_path)

        # BOS / EOS token IDs
        self.n_words: int = self.sp_model.vocab_size()
        self.bos_id: int = self.sp_model.bos_id()
        self.eos_id: int = self.sp_model.eos_id()
        self.pad_id: int = self.sp_model.unk_id()
        assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()

        special_tokens = [
            "[MASK]",
            "[gMASK]",
            "[sMASK]",
            "sop",
            "eop",
            "<|system|>",
            "<|user|>",
            "<|assistant|>",
            "<|observation|>",
        ]
        self.special_tokens = {}
        self.index_special_tokens = {}
        for token in special_tokens:
            self.special_tokens[token] = self.n_words
            self.index_special_tokens[self.n_words] = token
            self.n_words += 1

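    # The special tokens above are not pieces inside the SentencePiece model;
    # they are assigned ids immediately after the base vocabulary. Purely for
    # illustration (hypothetical numbers): with 64000 base pieces, "[MASK]"
    # would map to 64000, "[gMASK]" to 64001, and so on, with n_words growing
    # by one per special token.
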
    def tokenize(self, s: str):
        return self.sp_model.EncodeAsPieces(s)

    def encode(self, s: str, bos: bool = False, eos: bool = False) -> List[int]:
        assert type(s) is str
        t = self.sp_model.encode(s)
        if bos:
            t = [self.bos_id] + t
        if eos:
            t = t + [self.eos_id]
        return t

    def decode(self, t: List[int]) -> str:
        text, buffer = "", []
        for token in t:
            if token in self.index_special_tokens:
                if buffer:
                    text += self.sp_model.decode(buffer)
                    buffer = []
                text += self.index_special_tokens[token]
            else:
                buffer.append(token)
        if buffer:
            text += self.sp_model.decode(buffer)
        return text

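    # Decoding sketch: ordinary ids are buffered and decoded through
    # SentencePiece in a single call, while ids found in index_special_tokens
    # are spliced in as literal strings. Hypothetical example, where sp is an
    # SPTokenizer instance and 64000 happens to be the id of "[MASK]":
    #
    #   sp.decode([5, 6, 64000, 7])
    #   == sp.sp_model.decode([5, 6]) + "[MASK]" + sp.sp_model.decode([7])
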
    def decode_tokens(self, tokens: List[str]) -> str:
        text = self.sp_model.DecodePieces(tokens)
        return text

    def convert_token_to_id(self, token):
        """Converts a token (str) to an id using the vocab."""
        if token in self.special_tokens:
            return self.special_tokens[token]
        return self.sp_model.PieceToId(token)

    def convert_id_to_token(self, index):
        """Converts an index (integer) to a token (str) using the vocab."""
        if index in self.index_special_tokens:
            return self.index_special_tokens[index]
        if index in [self.eos_id, self.bos_id, self.pad_id] or index < 0:
            return ""
        return self.sp_model.IdToPiece(index)

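
# Usage sketch for SPTokenizer (illustrative only; assumes a local SentencePiece
# model file such as the tokenizer.model shipped with a ChatGLM checkpoint):
#
#   sp = SPTokenizer("tokenizer.model")
#   ids = sp.encode("Hello, world", eos=True)      # text ids followed by eos_id
#   text = sp.decode(ids)                          # back to a string
#   mask_id = sp.convert_token_to_id("[MASK]")     # id appended after the SP vocab
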
class ChatGLMTokenizer(PreTrainedTokenizer):
    vocab_files_names = {"vocab_file": "tokenizer.model"}
    model_input_names = ["input_ids", "attention_mask", "position_ids"]

    def __init__(self, vocab_file, padding_side="left", clean_up_tokenization_spaces=False, **kwargs):
        self.name = "GLMTokenizer"
        self.vocab_file = vocab_file
        self.tokenizer = SPTokenizer(vocab_file)
        self.special_tokens = {
            "<bos>": self.tokenizer.bos_id,
            "<eos>": self.tokenizer.eos_id,
            "<pad>": self.tokenizer.pad_id,
        }
        super().__init__(padding_side=padding_side, clean_up_tokenization_spaces=clean_up_tokenization_spaces, **kwargs)

    def get_command(self, token):
        if token in self.special_tokens:
            return self.special_tokens[token]
        assert token in self.tokenizer.special_tokens, f"{token} is not a special token for {self.name}"
        return self.tokenizer.special_tokens[token]

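    # get_command resolves both the generic markers ("<bos>", "<eos>", "<pad>")
    # and the chat/control tokens registered by SPTokenizer. For example,
    # get_command("<|user|>") returns the id that SPTokenizer assigned to
    # "<|user|>" when appending it after the base vocabulary.
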
    @property
    def unk_token(self) -> str:
        return "<unk>"

    @property
    def pad_token(self) -> str:
        # The SentencePiece model defines no dedicated pad piece, so padding
        # reuses the unk piece (pad_id is set to unk_id in SPTokenizer).
        return "<unk>"

    @property
    def pad_token_id(self):
        return self.get_command("<pad>")

    @property
    def eos_token(self) -> str:
        return "</s>"

    @property
    def eos_token_id(self):
        return self.get_command("<eos>")

    @property
    def vocab_size(self):
        return self.tokenizer.n_words

    def get_vocab(self):
        """Returns vocab as a dict"""
        vocab = {self._convert_id_to_token(i): i for i in range(self.vocab_size)}
        vocab.update(self.added_tokens_encoder)
        return vocab

    def _tokenize(self, text, **kwargs):
        return self.tokenizer.tokenize(text)

    def _convert_token_to_id(self, token):
        """Converts a token (str) to an id using the vocab."""
        return self.tokenizer.convert_token_to_id(token)

    def _convert_id_to_token(self, index):
        """Converts an index (integer) to a token (str) using the vocab."""
        return self.tokenizer.convert_id_to_token(index)

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        return self.tokenizer.decode_tokens(tokens)

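    # The _tokenize / _convert_token_to_id / _convert_id_to_token /
    # convert_tokens_to_string hooks above are the methods PreTrainedTokenizer
    # dispatches to internally, so the usual transformers entry points should
    # work unchanged. Illustrative sketch (tok is a ChatGLMTokenizer instance):
    #
    #   ids = tok("Hello")["input_ids"]
    #   text = tok.decode(ids)
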
    def build_single_message(self, role, metadata, message):
        assert role in ["system", "user", "assistant", "observation"], role
        role_tokens = [self.get_command(f"<|{role}|>")] + self.tokenizer.encode(f"{metadata}\n")
        message_tokens = self.tokenizer.encode(message)
        tokens = role_tokens + message_tokens
        return tokens

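    # A single chat message is laid out as:
    #   [<|role|> command id] + encode(metadata + "\n") + encode(message)
    # so, for example, build_single_message("user", "", "Hi") yields the
    # <|user|> id followed by the SentencePiece ids for "\n" and "Hi".
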
    def build_chat_input(self, query, history=None, role="user"):
        if history is None:
            history = []
        input_ids = []
        for item in history:
            content = item["content"]
            if item["role"] == "system" and "tools" in item:
                content = content + "\n" + json.dumps(item["tools"], indent=4, ensure_ascii=False)
            input_ids.extend(self.build_single_message(item["role"], item.get("metadata", ""), content))
        input_ids.extend(self.build_single_message(role, "", query))
        input_ids.extend([self.get_command("<|assistant|>")])
        return self.batch_encode_plus([input_ids], return_tensors="pt", is_split_into_words=True)
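

if __name__ == "__main__":
    # Minimal usage sketch, not part of the tokenizer itself. It assumes a
    # ChatGLM checkpoint's tokenizer.model file is present in the working
    # directory (the path is hypothetical; adjust it to your checkpoint).
    tokenizer = ChatGLMTokenizer(vocab_file="tokenizer.model")
    history = [
        {"role": "user", "content": "Hello"},
        {"role": "assistant", "metadata": "", "content": "Hi, how can I help you?"},
    ]
    batch = tokenizer.build_chat_input("What can you do?", history=history, role="user")
    print(batch["input_ids"].shape)  # torch.Size([1, sequence_length])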