Witllm/chatglm/modeling_chatglm.py

643 lines
23 KiB
Python
Raw Normal View History

2023-12-21 16:53:47 +08:00
import collections
import math
import copy
import os
import gc
import json
2023-12-25 16:22:45 +08:00
import hashlib
2023-12-21 16:53:47 +08:00
import torch
import torch.utils.checkpoint
import torch.nn.functional as F
from torch import nn
from torch.nn.utils import skip_init
from typing import Optional, Tuple, Union, List, Dict, Any
from tqdm import auto as tqdm_lib
from safetensors.torch import storage_ptr, storage_size
from transformers.configuration_utils import PretrainedConfig
from transformers.generation import GenerationConfig
2024-01-03 20:26:26 +08:00
from configuration_chatglm import ChatGLMConfig
2023-12-21 20:50:10 +08:00
from tools import show
2023-12-21 16:53:47 +08:00
class RotaryEmbedding(nn.Module):
    """Build the cos/sin table used for rotary position embeddings (RoPE).

    ``forward(n)`` returns a tensor of shape ``[n, len(range(0, dim, 2)), 2]``
    whose last axis holds ``(cos, sin)`` of ``position * theta_i`` with
    ``theta_i = base ** (-i / dim)`` for the even indices ``i = 0, 2, ...``.
    """

    # Output dtype of the cache, keyed by the dtype of ``inv_freq``. Low-precision
    # models get a half/bfloat16 cache to mimic complex32 behaviour.
    _CACHE_DTYPES = {
        torch.float16: torch.float16,
        torch.bfloat16: torch.bfloat16,
        torch.int8: torch.float16,
    }

    def __init__(self, dim: int, original_impl=False, device=None, dtype=None):
        super().__init__()
        exponents = torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim
        self.register_buffer("inv_freq", 1.0 / (10000 ** exponents))
        self.dim = dim
        # Kept for config compatibility; forward() does not branch on it.
        self.original_impl = original_impl

    def forward(self, max_seq_len: int, base: int = 10000):
        """Return the cos/sin table for positions ``0 .. max_seq_len - 1``.

        The angles are recomputed in float32 from ``base``; ``inv_freq`` only
        supplies the device and decides the output dtype.
        """
        out_dtype = self.inv_freq.dtype
        device = self.inv_freq.device
        # $\Theta = {\theta_i = base^{-\frac{i}{d}}, i \in [0, 2, ..., d)}$
        even_idx = torch.arange(0, self.dim, 2, dtype=torch.float, device=device)
        theta = 1.0 / (base ** (even_idx / self.dim))
        positions = torch.arange(max_seq_len, dtype=torch.float, device=device)
        angles = torch.outer(positions, theta).float()
        cache = torch.stack([torch.cos(angles), torch.sin(angles)], dim=-1)
        low_precision = self._CACHE_DTYPES.get(out_dtype)
        if low_precision is not None:
            cache = cache.to(low_precision)
        return cache
class RMSNorm(torch.nn.Module):
    """Root-mean-square normalisation over the last axis with a learned scale.

    Unlike LayerNorm there is no mean subtraction and no bias. ``weight`` is
    allocated with ``torch.empty`` (uninitialised) and must be loaded or set
    by the caller before use.
    """

    def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs):
        super().__init__()
        # Extra keyword arguments are accepted and ignored.
        self.weight = torch.nn.Parameter(
            torch.empty(normalized_shape, device=device, dtype=dtype)
        )
        self.eps = eps

    def forward(self, hidden_states: torch.Tensor):
        """Normalise ``hidden_states`` and return it in its original dtype."""
        original_dtype = hidden_states.dtype
        # The mean square is accumulated in float32 for numerical stability.
        mean_sq = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_sq + self.eps)
        normalized = hidden_states * inv_rms
        # show.DumpTensorToImage(self.weight, "RMSNorm_weight.png")
        return (self.weight * normalized).to(original_dtype)
class CoreAttention(torch.nn.Module):
    """Causal scaled dot-product attention over already-projected q/k/v heads.

    Inputs are ``[seq, batch, heads, head_dim]``. The output merges the heads
    into ``[seq_q, batch, heads * head_dim]``.
    """

    def __init__(self, config: "ChatGLMConfig", layer_number):
        super(CoreAttention, self).__init__()
        self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
        self.layer_number = max(1, layer_number)
        projection_size = config.kv_channels * config.num_attention_heads
        # Per attention head and per partition values.
        self.hidden_size_per_partition = projection_size
        self.hidden_size_per_attention_head = (
            projection_size // config.num_attention_heads
        )
        self.num_attention_heads_per_partition = config.num_attention_heads
        # NOTE: norm_factor / coeff (layer-number scaling) and attention_dropout are
        # kept as attributes for compatibility, but forward() does not use them:
        # scaled_dot_product_attention applies its default 1/sqrt(head_dim) scale
        # and no dropout.
        coeff = self.layer_number
        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) * coeff
        self.coeff = coeff
        self.attention_dropout = torch.nn.Dropout(config.attention_dropout)

    def forward(self, query_layer, key_layer, value_layer):
        """Attend each query to itself and all earlier keys.

        Args:
            query_layer: ``[sq, b, np, hn]``.
            key_layer, value_layer: ``[sk, b, np, hn]`` with ``sk >= sq``. When
                ``sk > sq`` the queries are taken to be the *last* ``sq``
                positions of the key sequence (e.g. keys include a cached prefix).

        Returns:
            ``[sq, b, np * hn]`` context tensor.

        Raises:
            ValueError: if there are more queries than keys.
        """
        # [s, b, np, hn] -> [b, np, s, hn], the layout SDPA expects.
        query_layer, key_layer, value_layer = [
            t.permute(1, 2, 0, 3) for t in (query_layer, key_layer, value_layer)
        ]
        q_len, k_len = query_layer.shape[2], key_layer.shape[2]
        if q_len == k_len:
            context_layer = torch.nn.functional.scaled_dot_product_attention(
                query_layer, key_layer, value_layer, is_causal=True
            )
        elif q_len < k_len:
            # Bottom-right aligned causal mask: query i may see keys
            # 0 .. (k_len - q_len + i). True means "attend".
            causal_mask = torch.ones(
                q_len, k_len, dtype=torch.bool, device=query_layer.device
            ).tril(diagonal=k_len - q_len)
            context_layer = torch.nn.functional.scaled_dot_product_attention(
                query_layer, key_layer, value_layer, attn_mask=causal_mask
            )
        else:
            raise ValueError(
                f"CoreAttention got {q_len} queries but only {k_len} keys; "
                "queries must not outnumber keys"
            )
        # [b, np, sq, hn] -> [sq, b, np, hn] -> [sq, b, np * hn]
        context_layer = context_layer.permute(2, 0, 1, 3)
        new_context_layer_shape = context_layer.size()[:-2] + (
            self.hidden_size_per_partition,
        )
        return context_layer.reshape(*new_context_layer_shape)
class SelfAttention(torch.nn.Module):
    """Grouped-query self-attention with rotary position embeddings.

    Takes ``[sq, b, h]`` hidden states and returns ``[sq, b, h]``. Queries have
    ``num_attention_heads`` heads. Keys/values have ``multi_query_group_num``
    heads, each shared by ``num_attention_heads // multi_query_group_num``
    adjacent query heads.
    """

    def __init__(self, config: "ChatGLMConfig", layer_number, device=None):
        super(SelfAttention, self).__init__()
        self.layer_number = max(1, layer_number)
        self.projection_size = config.kv_channels * config.num_attention_heads
        self.hidden_size_per_attention_head = (
            self.projection_size // config.num_attention_heads
        )
        self.num_attention_heads_per_partition = config.num_attention_heads
        self.multi_query_attention = config.multi_query_attention
        self.num_multi_query_groups_per_partition = config.multi_query_group_num
        # Fused QKV width: every query head plus one key and one value head per group.
        self.qkv_hidden_size = (
            self.projection_size
            + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num
        )
        self.query_key_value = nn.Linear(
            config.hidden_size,
            self.qkv_hidden_size,
            bias=config.add_bias_linear or config.add_qkv_bias,
            device=device,
            dtype=config.torch_dtype,
        )
        self.core_attention = CoreAttention(config, self.layer_number)
        self.dense = nn.Linear(
            self.projection_size,
            config.hidden_size,
            bias=config.add_bias_linear,
            device=device,
            dtype=config.torch_dtype,
        )

    def apply_rotary_pos_emb(self, x: torch.Tensor, rope: torch.Tensor) -> torch.Tensor:
        """Rotate the leading ``rot_dim`` channels of every head of ``x``.

        Args:
            x: ``[sq, b, np, hn]``.
            rope: ``[sq, b, rot_dim // 2, 2]`` of ``(cos, sin)`` pairs (the layout
                produced by RotaryEmbedding indexed by position ids). ``rot_dim``
                is ``2 * rope.size(-2)``; channels past it pass through unchanged.

        Raises:
            ValueError: if ``rope`` does not cover exactly ``sq`` positions.
        """
        sq, num_heads = x.size(0), x.size(2)
        if rope.size(0) != sq:
            raise ValueError(
                f"Error rotary_pos_emb size: got {rope.size(0)} positions, "
                f"expected {sq}"
            )
        rot_dim = rope.size(-2) * 2
        x_rope = x[..., :rot_dim]
        x_pass = x[..., rot_dim:]
        # Consecutive channel pairs (x0, x1) are rotated by the matching (cos, sin).
        x_rope = x_rope.reshape(sq, -1, num_heads, rot_dim // 2, 1, 2)
        rope = rope.view(sq, -1, 1, rot_dim // 2, 1, 2)
        roped1 = x_rope[..., 0] * rope[..., 0] - x_rope[..., 1] * rope[..., 1]
        roped2 = x_rope[..., 1] * rope[..., 0] + x_rope[..., 0] * rope[..., 1]
        x_out = torch.cat((roped1, roped2), -1).flatten(3)
        return torch.cat((x_out, x_pass), dim=-1)

    def _expand_kv_groups(self, layer: torch.Tensor) -> torch.Tensor:
        """Repeat each key/value group so there is one head per query head.

        ``[sq, b, groups, hn] -> [sq, b, heads, hn]``. Heads that share a group
        are adjacent.
        """
        heads_per_group = (
            self.num_attention_heads_per_partition
            // self.num_multi_query_groups_per_partition
        )
        layer = layer.unsqueeze(-2).expand(-1, -1, -1, heads_per_group, -1)
        return layer.contiguous().view(
            layer.size()[:2]
            + (
                self.num_attention_heads_per_partition,
                self.hidden_size_per_attention_head,
            )
        )

    def forward(self, hidden_states, rotary_pos_emb):
        """Project, rotate, attend and re-project ``hidden_states`` (``[sq, b, h]``)."""
        num_heads = self.num_attention_heads_per_partition
        num_groups = self.num_multi_query_groups_per_partition
        head_dim = self.hidden_size_per_attention_head

        # [sq, b, h] -> [sq, b, (np + 2 * groups) * hn]
        mixed_x_layer = self.query_key_value(hidden_states)
        query_layer, key_layer, value_layer = mixed_x_layer.split(
            [num_heads * head_dim, num_groups * head_dim, num_groups * head_dim],
            dim=-1,
        )
        query_layer = query_layer.view(query_layer.size()[:-1] + (num_heads, head_dim))
        key_layer = key_layer.view(key_layer.size()[:-1] + (num_groups, head_dim))
        value_layer = value_layer.view(value_layer.size()[:-1] + (num_groups, head_dim))

        # Apply relative positional encoding (rotary embedding) to q and k only.
        if rotary_pos_emb is not None:
            query_layer = self.apply_rotary_pos_emb(query_layer, rotary_pos_emb)
            key_layer = self.apply_rotary_pos_emb(key_layer, rotary_pos_emb)

        key_layer = self._expand_kv_groups(key_layer)
        value_layer = self._expand_kv_groups(value_layer)

        context_layer = self.core_attention(query_layer, key_layer, value_layer)
        return self.dense(context_layer)  # [sq, b, h]
class MLP(torch.nn.Module):
    """SwiGLU feed-forward block: ``h -> 2*ffn`` (gate, up), ``silu(gate) * up``, ``ffn -> h``."""

    def __init__(self, config: ChatGLMConfig, device=None):
        super(MLP, self).__init__()
        self.add_bias = config.add_bias_linear
        linear_kwargs = dict(bias=self.add_bias, device=device, dtype=config.torch_dtype)
        # Project to 4h. With SwiGLU the output width is doubled so it can be split
        # into gate and value halves, see https://arxiv.org/pdf/2002.05202.pdf
        self.dense_h_to_4h = nn.Linear(
            config.hidden_size, config.ffn_hidden_size * 2, **linear_kwargs
        )
        self.activation_func = self._swiglu
        self.dense_4h_to_h = nn.Linear(
            config.ffn_hidden_size, config.hidden_size, **linear_kwargs
        )

    @staticmethod
    def _swiglu(x):
        """Split the last axis in half and return ``silu(first) * second``."""
        halves = torch.chunk(x, 2, dim=-1)
        return F.silu(halves[0]) * halves[1]

    def forward(self, hidden_states):
        """Apply the feed-forward block to ``hidden_states`` (last axis = hidden size)."""
        widened = self.dense_h_to_4h(hidden_states)
        return self.dense_4h_to_h(self.activation_func(widened))
class GLMBlock(torch.nn.Module):
    """A single transformer layer.

    Transformer layer takes input with size [s, b, h] and returns an
    output of the same size: pre-norm self-attention followed by a pre-norm
    MLP, each wrapped in a residual connection.
    """

    def __init__(self, config: "ChatGLMConfig", layer_number, device=None):
        super(GLMBlock, self).__init__()
        self.layer_number = layer_number
        # When True, residuals are taken from the layer-norm outputs instead of
        # the un-normalised inputs (Megatron-style option).
        self.apply_residual_connection_post_layernorm = (
            config.apply_residual_connection_post_layernorm
        )
        self.fp32_residual_connection = config.fp32_residual_connection
        LayerNormFunc = RMSNorm
        # Layernorm on the input data.
        self.input_layernorm = LayerNormFunc(
            config.hidden_size,
            eps=config.layernorm_epsilon,
            device=device,
            dtype=config.torch_dtype,
        )
        # Self attention.
        self.self_attention = SelfAttention(config, layer_number, device=device)
        self.hidden_dropout = config.hidden_dropout
        # Layernorm on the attention output.
        self.post_attention_layernorm = LayerNormFunc(
            config.hidden_size,
            eps=config.layernorm_epsilon,
            device=device,
            dtype=config.torch_dtype,
        )
        self.mlp = MLP(config, device=device)

    def forward(self, hidden_states, rotary_pos_emb):
        """Run attention and MLP sub-layers on ``hidden_states`` ([s, b, h])."""
        post_ln_residual = self.apply_residual_connection_post_layernorm

        # Layer norm at the beginning of the transformer layer.
        layernorm_output = self.input_layernorm(hidden_states)
        attention_output = self.self_attention(layernorm_output, rotary_pos_emb)
        residual = layernorm_output if post_ln_residual else hidden_states
        layernorm_input = torch.nn.functional.dropout(
            attention_output, p=self.hidden_dropout, training=self.training
        )
        layernorm_input = residual + layernorm_input

        # Layer norm post the self attention, then the MLP.
        layernorm_output = self.post_attention_layernorm(layernorm_input)
        mlp_output = self.mlp(layernorm_output)
        residual = layernorm_output if post_ln_residual else layernorm_input
        output = torch.nn.functional.dropout(
            mlp_output, p=self.hidden_dropout, training=self.training
        )
        return residual + output
class GLMTransformer(torch.nn.Module):
    """Stack of ``config.num_layers`` GLMBlocks followed by a final RMSNorm.

    Operates on sequence-first activations ``[s, b, h]``.
    """

    def __init__(self, config: ChatGLMConfig, device=None):
        super(GLMTransformer, self).__init__()
        self.fp32_residual_connection = config.fp32_residual_connection
        # Stored for config compatibility; the final norm below is always applied.
        self.post_layer_norm = config.post_layer_norm
        self.num_layers = config.num_layers
        # Layer numbers handed to GLMBlock are 1-based.
        self.layers = torch.nn.ModuleList(
            [
                GLMBlock(config, layer_no, device=device)
                for layer_no in range(1, self.num_layers + 1)
            ]
        )
        self.final_layernorm = RMSNorm(
            config.hidden_size,
            eps=config.layernorm_epsilon,
            device=device,
            dtype=config.torch_dtype,
        )

    def forward(self, hidden_states, rotary_pos_emb):
        """Run every layer in order, then apply the final layer norm."""
        for layer_idx in range(self.num_layers):
            block = self.layers[layer_idx]
            hidden_states = block(hidden_states, rotary_pos_emb)
        return self.final_layernorm(hidden_states)
class Embedding(torch.nn.Module):
    """Token-embedding lookup that returns sequence-first activations."""

    def __init__(self, config: ChatGLMConfig, device=None):
        super(Embedding, self).__init__()
        self.hidden_size = config.hidden_size
        self.word_embeddings = nn.Embedding(
            config.padded_vocab_size,
            self.hidden_size,
            dtype=config.torch_dtype,
            device=device,
        )

    def forward(self, input_ids):
        """Embed ``input_ids`` ([b, s]) and return a contiguous [s, b, h] tensor."""
        # Layout change to avoid explicit transposes downstream: [b s h] -> [s b h].
        batch_first = self.word_embeddings(input_ids)
        return batch_first.transpose(0, 1).contiguous()
class ChatGLMModel(nn.Module):
    """ChatGLM backbone: embedding, rotary table, transformer stack and LM head.

    ``forward`` returns the softmax distribution over the vocabulary for the
    last input position together with one token sampled from it.
    """

    def __init__(self, config: "ChatGLMConfig", device=None, empty_init=True):
        super().__init__()
        # empty_init is accepted for API compatibility only: sub-modules are always
        # created through skip_init (parameters left uninitialised) and are
        # expected to be filled from a checkpoint.
        init_method = skip_init
        init_kwargs = {}
        if device is not None:
            init_kwargs["device"] = device
        self.embedding = init_method(Embedding, config, **init_kwargs)
        self.num_layers = config.num_layers
        self.multi_query_group_num = config.multi_query_group_num
        self.kv_channels = config.kv_channels
        self.config = config
        # Rotary positional embeddings.
        self.seq_length = config.seq_length
        rotary_dim = (
            config.hidden_size // config.num_attention_heads
            if config.kv_channels is None
            else config.kv_channels
        )
        self.rotary_pos_emb = RotaryEmbedding(
            rotary_dim // 2,
            original_impl=config.original_rope,
            device=device,
            dtype=config.torch_dtype,
        )
        self.encoder = init_method(GLMTransformer, config, **init_kwargs)
        self.output_layer = init_method(
            nn.Linear,
            config.hidden_size,
            config.padded_vocab_size,
            bias=False,
            dtype=config.torch_dtype,
            **init_kwargs,
        )

    def forward(
        self,
        input_ids,
        position_ids: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        tokenizer=None,
    ):
        """Return ``(probs, next_tokens)`` for the last position of ``input_ids``.

        Args:
            input_ids: ``[b, s]`` token ids.
            position_ids: ``[b, s]`` positions used to index the rotary table;
                defaults to ``0 .. s-1`` for every row. Positions must be
                ``< self.seq_length``.
            output_hidden_states: resolved against the config but otherwise unused.
            tokenizer: unused; accepted for call-site compatibility.

        Returns:
            ``probs`` of shape ``[b, vocab]`` and ``next_tokens`` of shape ``[b]``
            sampled from ``probs`` with ``torch.multinomial``.
        """
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        if position_ids is None:
            # Without this, rotary_pos_emb[None] would add a dimension instead of
            # selecting positions and silently produce a wrongly shaped rope tensor.
            batch_size, seq_len = input_ids.shape
            position_ids = (
                torch.arange(seq_len, dtype=torch.long, device=input_ids.device)
                .unsqueeze(0)
                .repeat(batch_size, 1)
            )

        inputs_embeds = self.embedding(input_ids)  # [s, b, h]
        rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
        # show.DumpTensorToImage(rotary_pos_emb[:, :, 0], "rotary_pos_emb.png", scale=0.1)
        # [b, s, ...] -> [s, b, ...] to match the sequence-first activations.
        rotary_pos_emb = rotary_pos_emb[position_ids].transpose(0, 1).contiguous()

        hidden_states = self.encoder(inputs_embeds, rotary_pos_emb)
        # Only the last position is needed to predict the next token.
        hidden_states = hidden_states[-1:]
        lm_logits = self.output_layer(hidden_states).transpose(0, 1).contiguous()

        next_token_logits = lm_logits[:, -1, :]
        probs = nn.functional.softmax(next_token_logits, dim=-1)
        next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
        return probs, next_tokens
class ChatGLMForConditionalGeneration(nn.Module):
    """Top-level ChatGLM wrapper: checkpoint loading, sampling loop and chat helper.

    ``self.transformer`` is a ChatGLMModel whose forward returns
    ``(probs, next_tokens)`` for the last input position.
    """

    def __init__(self, config: "ChatGLMConfig", empty_init=True, device=None):
        super().__init__()
        self.max_sequence_length = config.max_length
        self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device)
        self.config = config
        self.main_input_name = "input_ids"
        self.name_or_path = config.name_or_path
        self.warnings_issued = {}
        self.generation_config = GenerationConfig.from_model_config(config)

    def from_pretrained(
        cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]]
    ):
        """Load sharded ``pytorch_model.bin`` weights from a local directory.

        Despite the HF-style name and the ``cls`` parameter, this is an instance
        method: ``cls`` is the model instance, weights are copied into it in place,
        and that same instance is returned.

        Args:
            pretrained_model_name_or_path: directory holding
                ``pytorch_model.bin.index.json`` and the shard files it names.

        Returns:
            The model, with ``is_loaded_in_4bit`` / ``is_loaded_in_8bit`` set to False.
        """
        load_in_8bit = False
        load_in_4bit = False
        model_dir = str(pretrained_model_name_or_path)
        index_path = os.path.join(model_dir, "pytorch_model.bin.index.json")
        print(f"loading weights file {index_path}")
        with open(index_path, "r", encoding="utf-8") as f:
            index = json.load(f)
        # Many tensors share a shard; load each shard file exactly once.
        shard_filenames = sorted(set(index["weight_map"].values()))
        shard_paths = [os.path.join(model_dir, name) for name in shard_filenames]
        model = cls._load_pretrained_model(shard_paths)
        model.is_loaded_in_4bit = load_in_4bit
        model.is_loaded_in_8bit = load_in_8bit
        return model

    def _load_state_dict_into_model(self, model_to_load, state_dict, start_prefix):
        """Copy tensors from ``state_dict`` into ``model_to_load`` recursively.

        Only sub-modules that own at least one key under their prefix are visited,
        so a partial (per-shard) state dict can be applied. Missing/unexpected
        keys are not collected. The returned list holds the messages that
        ``_load_from_state_dict`` appends to ``error_msgs`` (e.g. size mismatches).
        """
        metadata = getattr(state_dict, "_metadata", None)
        state_dict = state_dict.copy()
        if metadata is not None:
            state_dict._metadata = metadata
        error_msgs = []

        def load(module: nn.Module, state_dict, prefix=""):
            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
            # strict=True; missing/unexpected keys go into throw-away lists.
            args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
            if any(key.startswith(prefix) for key in state_dict):
                module._load_from_state_dict(*args)
            for name, child in module._modules.items():
                if child is not None:
                    load(child, state_dict, prefix + name + ".")

        load(model_to_load, state_dict, prefix=start_prefix)
        del state_dict
        return error_msgs

    def _load_pretrained_model(cls, resolved_archive_file):
        """Load every shard in ``resolved_archive_file`` into the model ``cls``.

        ``cls`` is the model instance (see ``from_pretrained``). Any loading error
        messages are printed instead of being silently dropped.
        """
        start_prefix = ""
        model_to_load = cls
        error_msgs = []
        shards = resolved_archive_file
        if len(resolved_archive_file) > 1:
            shards = tqdm_lib.tqdm(
                resolved_archive_file, desc="Loading checkpoint shards"
            )
        for shard_file in shards:
            # NOTE(security): torch.load unpickles the file; only load trusted
            # checkpoints.
            state_dict = torch.load(shard_file, map_location="cpu")
            error_msgs += cls._load_state_dict_into_model(
                model_to_load, state_dict, start_prefix
            )
            del state_dict  # force memory release
            gc.collect()
        if error_msgs:
            print(
                f"Error(s) in loading state_dict for {cls.__class__.__name__}:\n\t"
                + "\n\t".join(error_msgs)
            )
        else:
            print(
                f"All model checkpoint weights were used when initializing {cls.__class__.__name__}.\n"
            )
        return cls

    @torch.inference_mode()
    def chat(
        self,
        tokenizer,
        query: str,
        history: List[Tuple[str, str]] = None,
        role: str = "user",
    ):
        """Sample a reply to ``query``; return ``(response, history)``.

        ``response`` decodes the whole generated sequence (prompt tokens included,
        the terminating EOS excluded). ``history`` gets the ``{"role", "content"}``
        entry for the query appended; the reply itself is not added.
        """
        if history is None:
            history = []
        inputs = tokenizer.build_chat_input(query, history=history, role=role)
        inputs = inputs.to(next(self.parameters()).device)
        generation_config = copy.deepcopy(self.generation_config)
        inputs_tensor = inputs["input_ids"]
        input_ids = inputs_tensor.repeat_interleave(
            generation_config.num_return_sequences, dim=0
        )
        outputs = self.sample(
            input_ids,
            generation_config.pad_token_id,
            generation_config.eos_token_id,
            generation_config.output_hidden_states,
            tokenizer,
        )
        outputs = outputs[0].tolist()
        response = tokenizer.decode(outputs)
        history.append({"role": role, "content": query})
        return response, history

    def sample(
        self,
        input_ids: torch.LongTensor,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[Union[int, List[int]]] = None,
        output_hidden_states: Optional[bool] = None,
        tokenizer=None,
    ):
        """Sample tokens until every row has produced an EOS token.

        The full prefix is re-run through ``self.transformer`` at each step. Rows
        that already finished receive ``pad_token_id``. The step on which the
        last row finishes is not appended, so the returned ids do not end with
        that EOS.

        Args:
            input_ids: ``[b, s]`` prompt ids.
            pad_token_id: id written into rows that have finished.
            eos_token_id: one id or a list of ids; any of them ends a row.

        Returns:
            ``[b, s + generated]`` token ids.

        Raises:
            ValueError: if ``pad_token_id`` or ``eos_token_id`` is None.
        """
        if eos_token_id is None:
            raise ValueError("sample() needs eos_token_id to know when to stop")
        if pad_token_id is None:
            raise ValueError("sample() needs pad_token_id to pad finished rows")
        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]
        eos_token_id_tensor = torch.tensor(eos_token_id, device=input_ids.device)
        is_finished = torch.zeros(
            input_ids.shape[0], dtype=torch.long, device=input_ids.device
        )
        while True:
            batch_size, seq_length = input_ids.shape
            position_ids = (
                torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
                .unsqueeze(0)
                .repeat(batch_size, 1)
            )
            probs, next_tokens = self.transformer(
                input_ids=input_ids,
                position_ids=position_ids,
                output_hidden_states=output_hidden_states,
                tokenizer=tokenizer,
            )
            # Finished sentences get a padding token as their next token.
            pad_token = pad_token_id * is_finished
            next_tokens = next_tokens * (1 - is_finished) + pad_token
            # Compare each row's token against every EOS id ([b, 1] vs [k]) so the
            # result stays [b] no matter how many EOS ids there are.
            hit_eos = next_tokens.unsqueeze(-1).eq(eos_token_id_tensor).any(dim=-1)
            is_finished = is_finished | hit_eos
            if is_finished.min() == 1:  # every row in the batch is finished
                break
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        return input_ids

    def backward(
        self,
        tokenizer,
        query: str,
    ):
        """Experimental: back-propagate the probability of one sampled next token.

        Runs a single forward pass on the chat prompt for ``query`` and samples a
        token from the returned distribution. It then calls ``.backward()`` on that
        token's probability for the first row and returns it.
        NOTE(review): ``loss`` has one element per row; ``.backward()`` only
        works when that is exactly one element (one prompt,
        ``num_return_sequences == 1``) — confirm this is the intended use.
        """
        inputs = tokenizer.build_chat_input(query, history=[], role="user")
        inputs = inputs.to(next(self.parameters()).device)
        generation_config = copy.deepcopy(self.generation_config)
        inputs_tensor = inputs["input_ids"]
        input_ids = inputs_tensor.repeat_interleave(
            generation_config.num_return_sequences, dim=0
        )
        input_ids_in = input_ids
        batch_size, seq_length = input_ids_in.shape
        position_ids_in = (
            torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
            .unsqueeze(0)
            .repeat(batch_size, 1)
        )
        model_inputs = {"input_ids": input_ids_in, "position_ids": position_ids_in}
        probs, next_tokens = self.transformer(
            **model_inputs,
            output_hidden_states=None,
            tokenizer=tokenizer,
        )
        next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
        loss = probs[0, next_tokens]
        loss.backward()
        return loss