Format qwen model.
parent 255a2ff71c
commit 611396b656
@@ -38,7 +39,9 @@ from torch import nn
 SUPPORT_CUDA = torch.cuda.is_available()
 SUPPORT_BF16 = SUPPORT_CUDA and torch.cuda.is_bf16_supported()
 SUPPORT_FP16 = SUPPORT_CUDA and torch.cuda.get_device_capability(0)[0] >= 7
-SUPPORT_TORCH2 = hasattr(torch, '__version__') and int(torch.__version__.split(".")[0]) >= 2
+SUPPORT_TORCH2 = (
+    hasattr(torch, "__version__") and int(torch.__version__.split(".")[0]) >= 2
+)
 
 
 from configuration_qwen import QWenConfig
@@ -70,6 +72,7 @@ Pass argument `stream` to model.chat() is buggy, deprecated, and marked for remo
 apply_rotary_emb_func = None
 rms_norm = None
 
+
 def quantize_cache_v(fdata, bits, qmax, qmin):
     # b, s, head, h-dim->b, head, s, h-dim
     qtype = torch.uint8
@@ -85,17 +88,19 @@ def quantize_cache_v(fdata, bits, qmax, qmin):
     qmin = qmin.to(device)
     scale = (fmax - fmin) / (qmax - qmin)
     zero = qmin - fmin / scale
-    scale = scale.unsqueeze(-1).repeat(1,1,shape[2],1).contiguous()
-    zero = zero.unsqueeze(-1).repeat(1,1,shape[2],1).contiguous()
+    scale = scale.unsqueeze(-1).repeat(1, 1, shape[2], 1).contiguous()
+    zero = zero.unsqueeze(-1).repeat(1, 1, shape[2], 1).contiguous()
     # Quantize
     res_data = fdata / scale + zero
     qdata = torch.clamp(res_data, qmin, qmax).to(qtype)
     return qdata.contiguous(), scale, zero
 
+
 def dequantize_cache_torch(qdata, scale, zero):
     data = scale * (qdata - zero)
     return data
 
+
 class QWenAttention(nn.Module):
     def __init__(self, config):
         super().__init__()
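For context on the two helpers reformatted above: quantize_cache_v packs a float KV-cache tensor into uint8 plus a scale and zero point, and dequantize_cache_torch inverts the mapping. A minimal round-trip sketch, assuming the patched modeling_qwen.py is importable as a module and using made-up tensor shapes (none of this is part of the commit):

import torch

from modeling_qwen import dequantize_cache_torch, quantize_cache_v  # assumed import path

# (batch, seq, head, head_dim) cache slice; shape and values are illustrative only.
cache = torch.randn(1, 8, 2, 16)
qmax = torch.tensor(torch.iinfo(torch.uint8).max, dtype=torch.float)
qmin = torch.tensor(torch.iinfo(torch.uint8).min, dtype=torch.float)

# Permute to (batch, head, seq, head_dim) before quantizing, mirroring QWenAttention.forward.
qdata, scale, zero = quantize_cache_v(
    cache.permute(0, 2, 1, 3), bits=8, qmax=qmax, qmin=qmin
)
recovered = dequantize_cache_torch(qdata, scale, zero)
print(qdata.dtype, (recovered - cache.permute(0, 2, 1, 3)).abs().max())  # uint8, small error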
@@ -138,12 +143,20 @@ class QWenAttention(nn.Module):
         self.register_buffer("logn_tensor", logn_tensor, persistent=False)
 
         self.attn_dropout = nn.Dropout(config.attn_dropout_prob)
-        self.softmax_in_fp32 = config.softmax_in_fp32 if hasattr(config, 'softmax_in_fp32') else False
-        self.use_cache_quantization = config.use_cache_quantization if hasattr(config, 'use_cache_quantization') else False
-        self.use_cache_kernel = config.use_cache_kernel if hasattr(config,'use_cache_kernel') else False
+        self.softmax_in_fp32 = (
+            config.softmax_in_fp32 if hasattr(config, "softmax_in_fp32") else False
+        )
+        self.use_cache_quantization = (
+            config.use_cache_quantization
+            if hasattr(config, "use_cache_quantization")
+            else False
+        )
+        self.use_cache_kernel = (
+            config.use_cache_kernel if hasattr(config, "use_cache_kernel") else False
+        )
         cache_dtype = torch.float
         if self.bf16:
-            cache_dtype=torch.bfloat16
+            cache_dtype = torch.bfloat16
         elif config.fp16:
             cache_dtype = torch.float16
         self.cache_qmax = torch.tensor(torch.iinfo(torch.uint8).max, dtype=cache_dtype)
@@ -152,19 +165,25 @@ class QWenAttention(nn.Module):
         if config.use_cache_quantization and config.use_cache_kernel:
             # pre check if the support files existing
             module_root = pathlib.Path(__file__).parent
-            src_files = ("cache_autogptq_cuda_256.cpp", "cache_autogptq_cuda_kernel_256.cu")
-            if any(not (module_root/src).is_file() for src in src_files):
+            src_files = (
+                "cache_autogptq_cuda_256.cpp",
+                "cache_autogptq_cuda_kernel_256.cu",
+            )
+            if any(not (module_root / src).is_file() for src in src_files):
                 warnings.warn("KV cache kernel source files (.cpp and .cu) not found.")
                 self.cache_kernels = None
             else:
                 try:
                     from .cpp_kernels import cache_autogptq_cuda_256
+
                     self.cache_kernels = cache_autogptq_cuda_256
                 except ImportError:
                     warnings.warn("Failed to import KV cache kernels.")
                     self.cache_kernels = None
 
-    def _attn(self, query, key, value, causal_mask=None, attention_mask=None, head_mask=None):
+    def _attn(
+        self, query, key, value, causal_mask=None, attention_mask=None, head_mask=None
+    ):
         device = query.device
         if self.use_cache_quantization:
             qk, qk_scale, qk_zero = key
@@ -172,11 +191,18 @@ class QWenAttention(nn.Module):
                 shape = query.shape[:-1] + (qk.shape[-2],)
                 attn_weights = torch.zeros(shape, dtype=torch.float16, device=device)
                 self.cache_kernels.vecquant8matmul_batched_faster_old(
-                    query.contiguous() if query.dtype == torch.float16 else query.to(torch.float16).contiguous(),
+                    query.contiguous()
+                    if query.dtype == torch.float16
+                    else query.to(torch.float16).contiguous(),
                     qk.transpose(-1, -2).contiguous(),
                     attn_weights,
-                    qk_scale.contiguous() if qk_scale.dtype == torch.float16 else qk_scale.to(torch.float16).contiguous(),
-                    qk_zero.contiguous()if qk_zero.dtype == torch.float16 else qk_zero.to(torch.float16).contiguous())
+                    qk_scale.contiguous()
+                    if qk_scale.dtype == torch.float16
+                    else qk_scale.to(torch.float16).contiguous(),
+                    qk_zero.contiguous()
+                    if qk_zero.dtype == torch.float16
+                    else qk_zero.to(torch.float16).contiguous(),
+                )
                 # attn_weights = attn_weights.to(query.dtype).contiguous()
             else:
                 key = dequantize_cache_torch(qk, qk_scale, qk_zero)
@@ -189,7 +215,7 @@ class QWenAttention(nn.Module):
             size_temp = value[0].size(-1)
         else:
             size_temp = value.size(-1)
-        attn_weights = attn_weights / (size_temp ** 0.5)
+        attn_weights = attn_weights / (size_temp**0.5)
 
         mask_value = torch.finfo(attn_weights.dtype).min
         if causal_mask is not None:
@@ -217,11 +243,18 @@ class QWenAttention(nn.Module):
                 shape = attn_weights.shape[:-1] + (query.shape[-1],)
                 attn_output = torch.zeros(shape, dtype=torch.float16, device=device)
                 self.cache_kernels.vecquant8matmul_batched_column_compression_faster_old(
-                    attn_weights.contiguous() if attn_weights.dtype == torch.float16 else attn_weights.to(torch.float16).contiguous(),
+                    attn_weights.contiguous()
+                    if attn_weights.dtype == torch.float16
+                    else attn_weights.to(torch.float16).contiguous(),
                     qv.contiguous(),  # dtype: int32
                     attn_output,
-                    qv_scale.contiguous() if qv_scale.dtype == torch.float16 else qv_scale.to(torch.float16).contiguous(),
-                    qv_zero.contiguous() if qv_zero.dtype == torch.float16 else qv_zero.to(torch.float16).contiguous())
+                    qv_scale.contiguous()
+                    if qv_scale.dtype == torch.float16
+                    else qv_scale.to(torch.float16).contiguous(),
+                    qv_zero.contiguous()
+                    if qv_zero.dtype == torch.float16
+                    else qv_zero.to(torch.float16).contiguous(),
+                )
                 if attn_output.dtype != query.dtype:
                     attn_output = attn_output.to(query.dtype)
                     attn_weights = attn_weights.to(query.dtype)
@@ -283,21 +316,26 @@ class QWenAttention(nn.Module):
                     rotary_pos_emb = (rotary_pos_emb,) * 2
                     q_pos_emb, k_pos_emb = rotary_pos_emb
                     # Slice the pos emb for current inference
-                    query_list += [apply_rotary_pos_emb(query[i:i+1, :, :], q_pos_emb)]
-                    key_list += [apply_rotary_pos_emb(key[i:i+1, :, :], k_pos_emb)]
+                    query_list += [
+                        apply_rotary_pos_emb(query[i : i + 1, :, :], q_pos_emb)
+                    ]
+                    key_list += [apply_rotary_pos_emb(key[i : i + 1, :, :], k_pos_emb)]
                 query = torch.cat(query_list, dim=0)
                 key = torch.cat(key_list, dim=0)
 
         if self.use_cache_quantization:
-            key = quantize_cache_v(key.permute(0, 2, 1, 3),
-                                   bits=8,
-                                   qmin=self.cache_qmin,
-                                   qmax=self.cache_qmax)
-            value = quantize_cache_v(value.permute(0, 2, 1, 3),
-                                     bits=8,
-                                     qmin=self.cache_qmin,
-                                     qmax=self.cache_qmax)
+            key = quantize_cache_v(
+                key.permute(0, 2, 1, 3),
+                bits=8,
+                qmin=self.cache_qmin,
+                qmax=self.cache_qmax,
+            )
+            value = quantize_cache_v(
+                value.permute(0, 2, 1, 3),
+                bits=8,
+                qmin=self.cache_qmin,
+                qmax=self.cache_qmax,
+            )
 
         if layer_past is not None:
             past_key, past_value = layer_past[0], layer_past[1]
@@ -305,12 +343,16 @@ class QWenAttention(nn.Module):
                 # use_cache_quantization:
                 # present=((q_key,key_scale,key_zero_point),
                 # (q_value,value_scale,value_zero_point))
-                key = (torch.cat((past_key[0], key[0]), dim=2),
-                       torch.cat((past_key[1], key[1]), dim=2),
-                       torch.cat((past_key[2], key[2]), dim=2))
-                value = (torch.cat((past_value[0], value[0]), dim=2),
-                         torch.cat((past_value[1], value[1]), dim=2),
-                         torch.cat((past_value[2], value[2]), dim=2))
+                key = (
+                    torch.cat((past_key[0], key[0]), dim=2),
+                    torch.cat((past_key[1], key[1]), dim=2),
+                    torch.cat((past_key[2], key[2]), dim=2),
+                )
+                value = (
+                    torch.cat((past_value[0], value[0]), dim=2),
+                    torch.cat((past_value[1], value[1]), dim=2),
+                    torch.cat((past_value[2], value[2]), dim=2),
+                )
             else:
                 # not use_cache_quantization:
                 # present=(key,value)
@@ -347,11 +389,11 @@ class QWenAttention(nn.Module):
 
         if not self.use_cache_quantization and SUPPORT_TORCH2:
             if attention_mask is not None:
-                attention_mask = attention_mask.expand(
-                    -1, -1, causal_mask.size(2), -1
-                )
+                attention_mask = attention_mask.expand(-1, -1, causal_mask.size(2), -1)
                 if causal_mask is not None:
-                    attention_mask = attention_mask.masked_fill(~causal_mask, torch.finfo(query.dtype).min)
+                    attention_mask = attention_mask.masked_fill(
+                        ~causal_mask, torch.finfo(query.dtype).min
+                    )
             else:
                 attention_mask = causal_mask
             attn_output = F.scaled_dot_product_attention(
@@ -362,16 +404,16 @@ class QWenAttention(nn.Module):
             attn_output, attn_weight = self._attn(
                 query, key, value, causal_mask, attention_mask, head_mask
             )
-        context_layer = self._merge_heads(
-            attn_output, self.num_heads, self.head_dim
-        )
+        context_layer = self._merge_heads(attn_output, self.num_heads, self.head_dim)
 
         attn_output = self.c_proj(context_layer)
 
         outputs = (attn_output, present)
         if output_attentions:
             if not self.use_cache_quantization and SUPPORT_TORCH2:
-                raise ValueError("Cannot output attentions while using scaled_dot_product_attention")
+                raise ValueError(
+                    "Cannot output attentions while using scaled_dot_product_attention"
+                )
             else:
                 outputs += (attn_weight,)
 
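The SUPPORT_TORCH2 branch above routes attention through torch.nn.functional.scaled_dot_product_attention instead of the hand-written _attn, after turning the causal mask into a boolean or additive mask. A standalone sketch of that call with a boolean causal mask; the tensor sizes are illustrative and not taken from the model:

import torch
import torch.nn.functional as F

# (batch, heads, seq_len, head_dim) — example sizes only.
q = torch.randn(1, 16, 8, 64)
k = torch.randn(1, 16, 8, 64)
v = torch.randn(1, 16, 8, 64)

# Boolean mask: True = attend, False = masked out
# (equivalent to adding torch.finfo(dtype).min at masked positions, as done above).
causal_mask = torch.ones(8, 8, dtype=torch.bool).tril()
out = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask)
print(out.shape)  # torch.Size([1, 16, 8, 64])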
@@ -507,7 +549,11 @@ class QWenModel(QWenPreTrainedModel):
         self.vocab_size = config.vocab_size
         self.num_hidden_layers = config.num_hidden_layers
         self.embed_dim = config.hidden_size
-        self.use_cache_quantization = self.config.use_cache_quantization if hasattr(self.config, 'use_cache_quantization') else False
+        self.use_cache_quantization = (
+            self.config.use_cache_quantization
+            if hasattr(self.config, "use_cache_quantization")
+            else False
+        )
 
         self.gradient_checkpointing = False
         self.use_dynamic_ntk = config.use_dynamic_ntk
@@ -521,25 +567,14 @@ class QWenModel(QWenPreTrainedModel):
             self.rotary_ndims = None
         else:
             assert config.rotary_pct < 1
-            self.rotary_ndims = int(
-                config.kv_channels * config.rotary_pct
-            )
-        dim = (
-            self.rotary_ndims
-            if self.rotary_ndims is not None
-            else config.kv_channels
-        )
+            self.rotary_ndims = int(config.kv_channels * config.rotary_pct)
+        dim = self.rotary_ndims if self.rotary_ndims is not None else config.kv_channels
         self.rotary_emb = RotaryEmbedding(dim, base=config.rotary_emb_base)
 
         self.is_fp32 = not (config.bf16 or config.fp16)
 
         self.h = nn.ModuleList(
-            [
-                QWenBlock(
-                    config
-                )
-                for i in range(config.num_hidden_layers)
-            ]
+            [QWenBlock(config) for i in range(config.num_hidden_layers)]
         )
         self.ln_f = RMSNorm(
             self.embed_dim,
@@ -659,7 +694,12 @@ class QWenModel(QWenPreTrainedModel):
         else:
             ntk_alpha_list = []
             if attention_mask is not None and kv_seq_len > self.seq_length:
-                true_seq_lens = attention_mask.squeeze(1).squeeze(1).eq(0).sum(dim=-1, dtype=torch.int32)
+                true_seq_lens = (
+                    attention_mask.squeeze(1)
+                    .squeeze(1)
+                    .eq(0)
+                    .sum(dim=-1, dtype=torch.int32)
+                )
                 for i in range(hidden_states.size()[0]):
                     true_seq_len = true_seq_lens[i].item()
                     ntk_alpha = self.get_ntk_alpha(true_seq_len)
@@ -669,7 +709,8 @@ class QWenModel(QWenPreTrainedModel):
                 ntk_alpha_list.append(ntk_alpha)
         self.rotary_emb._ntk_alpha_cached_list = ntk_alpha_list
         rotary_pos_emb_list = [
-            self.rotary_emb(kv_seq_len, ntk_alpha=ntk_alpha) for ntk_alpha in ntk_alpha_list
+            self.rotary_emb(kv_seq_len, ntk_alpha=ntk_alpha)
+            for ntk_alpha in ntk_alpha_list
         ]
 
         hidden_states = self.drop(hidden_states)
@@ -686,7 +727,6 @@ class QWenModel(QWenPreTrainedModel):
         all_self_attentions = () if output_attentions else None
         all_hidden_states = () if output_hidden_states else None
         for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
-
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
 
@@ -727,7 +767,9 @@ class QWenModel(QWenPreTrainedModel):
                 presents = presents + (outputs[1],)
 
             if output_attentions:
-                all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
+                all_self_attentions = all_self_attentions + (
+                    outputs[2 if use_cache else 1],
+                )
 
         hidden_states = self.ln_f(hidden_states)
         hidden_states = hidden_states.view(output_shape)
@@ -756,7 +798,7 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         super().__init__(config)
         assert (
             config.bf16 + config.fp16 + config.fp32 <= 1
-        ), "Only one of \"bf16\", \"fp16\", \"fp32\" can be true"
+        ), 'Only one of "bf16", "fp16", "fp32" can be true'
 
         autoset_precision = config.bf16 + config.fp16 + config.fp32 == 0
 
@@ -764,27 +806,35 @@ class QWenLMHeadModel(QWenPreTrainedModel):
             if SUPPORT_BF16:
                 logger.warn(
                     "The model is automatically converting to bf16 for faster inference. "
-                    "If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to \"AutoModelForCausalLM.from_pretrained\"."
+                    'If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to "AutoModelForCausalLM.from_pretrained".'
                 )
                 config.bf16 = True
             elif SUPPORT_FP16:
                 logger.warn(
                     "The model is automatically converting to fp16 for faster inference. "
-                    "If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to \"AutoModelForCausalLM.from_pretrained\"."
+                    'If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to "AutoModelForCausalLM.from_pretrained".'
                 )
                 config.fp16 = True
             else:
                 config.fp32 = True
 
         if config.bf16 and SUPPORT_CUDA and not SUPPORT_BF16:
-            logger.warn("Your device does NOT seem to support bf16, you can switch to fp16 or fp32 by by passing fp16/fp32=True in \"AutoModelForCausalLM.from_pretrained\".")
+            logger.warn(
+                'Your device does NOT seem to support bf16, you can switch to fp16 or fp32 by by passing fp16/fp32=True in "AutoModelForCausalLM.from_pretrained".'
+            )
         if config.fp16 and SUPPORT_CUDA and not SUPPORT_FP16:
-            logger.warn("Your device does NOT support faster inference with fp16, please switch to fp32 which is likely to be faster")
+            logger.warn(
+                "Your device does NOT support faster inference with fp16, please switch to fp32 which is likely to be faster"
+            )
         if config.fp32:
             if SUPPORT_BF16:
-                logger.warn("Your device support faster inference by passing bf16=True in \"AutoModelForCausalLM.from_pretrained\".")
+                logger.warn(
+                    'Your device support faster inference by passing bf16=True in "AutoModelForCausalLM.from_pretrained".'
+                )
             elif SUPPORT_FP16:
-                logger.warn("Your device support faster inference by passing fp16=True in \"AutoModelForCausalLM.from_pretrained\".")
+                logger.warn(
+                    'Your device support faster inference by passing fp16=True in "AutoModelForCausalLM.from_pretrained".'
+                )
 
         self.transformer = QWenModel(config)
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
@@ -845,7 +895,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
-
         return_dict = (
             return_dict if return_dict is not None else self.config.use_return_dict
         )
@@ -887,7 +936,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         # )
         # loss.backward()
 
-
         if not return_dict:
             output = (lm_logits,) + transformer_outputs[1:]
             return ((loss,) + output) if loss is not None else output
@@ -904,7 +952,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
     def _reorder_cache(
         past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
     ) -> Tuple[Tuple[torch.Tensor]]:
-
         return tuple(
             tuple(
                 past_state.index_select(0, beam_idx.to(past_state.device))
@@ -924,10 +971,14 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         generation_config: Optional[GenerationConfig] = None,
         **kwargs,
     ) -> Tuple[str, HistoryType]:
-        generation_config = generation_config if generation_config is not None else self.generation_config
+        generation_config = (
+            generation_config
+            if generation_config is not None
+            else self.generation_config
+        )
 
         assert stream is _SENTINEL, _ERROR_STREAM_IN_CHAT
-        assert generation_config.chat_format == 'chatml', _ERROR_BAD_CHAT_FORMAT
+        assert generation_config.chat_format == "chatml", _ERROR_BAD_CHAT_FORMAT
         if history is None:
             history = []
         else:
@@ -937,7 +988,7 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         if stop_words_ids is None:
             stop_words_ids = []
 
-        max_window_size = kwargs.get('max_window_size', None)
+        max_window_size = kwargs.get("max_window_size", None)
        if max_window_size is None:
             max_window_size = generation_config.max_window_size
         raw_text, context_tokens = make_context(
@@ -949,17 +1000,17 @@ class QWenLMHeadModel(QWenPreTrainedModel):
             chat_format=generation_config.chat_format,
         )
 
-        stop_words_ids.extend(get_stop_words_ids(
-            generation_config.chat_format, tokenizer
-        ))
+        stop_words_ids.extend(
+            get_stop_words_ids(generation_config.chat_format, tokenizer)
+        )
         input_ids = torch.tensor([context_tokens]).to(self.device)
         outputs = self.generate(
             input_ids,
             stop_words_ids=stop_words_ids,
             return_dict_in_generate=False,
             generation_config=generation_config,
             **kwargs,
         )
 
         response = decode_tokens(
             outputs[0],
@@ -968,7 +1019,7 @@ class QWenLMHeadModel(QWenPreTrainedModel):
             context_length=len(context_tokens),
             chat_format=generation_config.chat_format,
             verbose=False,
-            errors='replace'
+            errors="replace",
         )
 
         # as history is a copy of the user inputs,
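The chat() method touched above chains make_context, generate and decode_tokens, and returns the reply together with the updated history. Typical usage is sketched below; the checkpoint name is a placeholder and loading details may differ:

from transformers import AutoModelForCausalLM, AutoTokenizer

# "Qwen/Qwen-7B-Chat" stands in for whatever checkpoint ships this modeling_qwen.py.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-7B-Chat", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen-7B-Chat", device_map="auto", trust_remote_code=True
).eval()

response, history = model.chat(tokenizer, "Hello!", history=None)
response, history = model.chat(tokenizer, "Summarize that in one line.", history=history)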
@@ -980,24 +1031,28 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         return response, history
 
     def chat_stream(
         self,
         tokenizer: PreTrainedTokenizer,
         query: str,
         history: Optional[HistoryType],
         system: str = "You are a helpful assistant.",
         stop_words_ids: Optional[List[List[int]]] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
         generation_config: Optional[GenerationConfig] = None,
         **kwargs,
     ) -> Generator[str, Any, None]:
-        generation_config = generation_config if generation_config is not None else self.generation_config
-        assert generation_config.chat_format == 'chatml', _ERROR_BAD_CHAT_FORMAT
+        generation_config = (
+            generation_config
+            if generation_config is not None
+            else self.generation_config
+        )
+        assert generation_config.chat_format == "chatml", _ERROR_BAD_CHAT_FORMAT
         if history is None:
             history = []
         if stop_words_ids is None:
             stop_words_ids = []
 
-        max_window_size = kwargs.get('max_window_size', None)
+        max_window_size = kwargs.get("max_window_size", None)
         if max_window_size is None:
             max_window_size = generation_config.max_window_size
         raw_text, context_tokens = make_context(
@@ -1009,9 +1064,9 @@ class QWenLMHeadModel(QWenPreTrainedModel):
             chat_format=generation_config.chat_format,
         )
 
-        stop_words_ids.extend(get_stop_words_ids(
-            generation_config.chat_format, tokenizer
-        ))
+        stop_words_ids.extend(
+            get_stop_words_ids(generation_config.chat_format, tokenizer)
+        )
         if stop_words_ids is not None:
             stop_words_logits_processor = StopWordsLogitsProcessor(
                 stop_words_ids=stop_words_ids,
@@ -1023,22 +1078,31 @@ class QWenLMHeadModel(QWenPreTrainedModel):
             logits_processor.append(stop_words_logits_processor)
         input_ids = torch.tensor([context_tokens]).to(self.device)
 
-        from transformers_stream_generator.main import NewGenerationMixin, StreamGenerationConfig
+        from transformers_stream_generator.main import (
+            NewGenerationMixin,
+            StreamGenerationConfig,
+        )
+
         self.__class__.generate_stream = NewGenerationMixin.generate
         self.__class__.sample_stream = NewGenerationMixin.sample_stream
-        stream_config = StreamGenerationConfig(**generation_config.to_dict(), do_stream=True)
+        stream_config = StreamGenerationConfig(
+            **generation_config.to_dict(), do_stream=True
+        )
 
         def stream_generator():
             outputs = []
             for token in self.generate_stream(
                 input_ids,
                 return_dict_in_generate=False,
                 generation_config=stream_config,
                 logits_processor=logits_processor,
                 seed=-1,
-                **kwargs):
+                **kwargs,
+            ):
                 outputs.append(token.item())
-                yield tokenizer.decode(outputs, skip_special_tokens=True, errors='ignore')
+                yield tokenizer.decode(
+                    outputs, skip_special_tokens=True, errors="ignore"
+                )
 
         return stream_generator()
 
@@ -1056,7 +1120,11 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         streamer: Optional["BaseStreamer"] = None,
         **kwargs,
     ) -> Union[GenerateOutput, torch.LongTensor]:
-        generation_config = generation_config if generation_config is not None else self.generation_config
+        generation_config = (
+            generation_config
+            if generation_config is not None
+            else self.generation_config
+        )
 
         # Process stop_words_ids.
         stop_words_ids = kwargs.pop("stop_words_ids", None)
@@ -1093,7 +1161,9 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         generation_config: Optional[GenerationConfig] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
         stopping_criteria: Optional[StoppingCriteriaList] = None,
-        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
+        prefix_allowed_tokens_fn: Optional[
+            Callable[[int, torch.Tensor], List[int]]
+        ] = None,
         synced_gpus: Optional[bool] = None,
         assistant_model: Optional["PreTrainedModel"] = None,
         streamer: Optional["BaseStreamer"] = None,
@@ -1101,9 +1171,8 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         negative_prompt_attention_mask: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> Union[GenerateOutput, torch.LongTensor]:
-
         if synced_gpus is None:
             synced_gpus = False
 
         # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
         self._validate_model_class()
@@ -1114,8 +1183,10 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         # two conditions must be met
         # 1) the generation config must have been created from the model config (`_from_model_config` field);
         # 2) the generation config must have seen no modification since its creation (the hash is the same).
-        if self.generation_config._from_model_config and self.generation_config._original_object_hash == hash(
-            self.generation_config
+        if (
+            self.generation_config._from_model_config
+            and self.generation_config._original_object_hash
+            == hash(self.generation_config)
         ):
             new_generation_config = GenerationConfig.from_model_config(self.config)
             if new_generation_config != self.generation_config:
@@ -1129,15 +1200,26 @@ class QWenLMHeadModel(QWenPreTrainedModel):
             generation_config = self.generation_config
 
         generation_config = copy.deepcopy(generation_config)
-        model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs
+        model_kwargs = generation_config.update(
+            **kwargs
+        )  # All unused kwargs must be model kwargs
         generation_config.validate()
         self._validate_model_kwargs(model_kwargs.copy())
 
         # 2. Set generation parameters if not already defined
-        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
-        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
+        logits_processor = (
+            logits_processor if logits_processor is not None else LogitsProcessorList()
+        )
+        stopping_criteria = (
+            stopping_criteria
+            if stopping_criteria is not None
+            else StoppingCriteriaList()
+        )
 
-        if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
+        if (
+            generation_config.pad_token_id is None
+            and generation_config.eos_token_id is not None
+        ):
             if model_kwargs.get("attention_mask", None) is None:
                 logger.warning(
                     "The attention mask and the pad token id were not set. As a consequence, you may observe "
@@ -1146,7 +1228,9 @@ class QWenLMHeadModel(QWenPreTrainedModel):
             eos_token_id = generation_config.eos_token_id
             if isinstance(eos_token_id, list):
                 eos_token_id = eos_token_id[0]
-            logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
+            logger.warning(
+                f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation."
+            )
             generation_config.pad_token_id = eos_token_id
 
         # 3. Define model inputs
@@ -1169,12 +1253,22 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         else:
             model_kwargs["use_cache"] = generation_config.use_cache
 
-        accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
+        accepts_attention_mask = "attention_mask" in set(
+            inspect.signature(self.forward).parameters.keys()
+        )
         requires_attention_mask = "encoder_outputs" not in model_kwargs
 
-        if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask:
-            model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
-                inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id
+        if (
+            model_kwargs.get("attention_mask", None) is None
+            and requires_attention_mask
+            and accepts_attention_mask
+        ):
+            model_kwargs[
+                "attention_mask"
+            ] = self._prepare_attention_mask_for_generation(
+                inputs_tensor,
+                generation_config.pad_token_id,
+                generation_config.eos_token_id,
             )
 
         # decoder-only models should use left-padding for generation
@@ -1184,7 +1278,8 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         if (
             generation_config.pad_token_id is not None
             and len(inputs_tensor.shape) == 2
-            and torch.sum(inputs_tensor[:, -1] == generation_config.pad_token_id) > 0
+            and torch.sum(inputs_tensor[:, -1] == generation_config.pad_token_id)
+            > 0
         ):
             logger.warning(
                 "A decoder-only architecture is being used, but right-padding was detected! For correct "
@@ -1209,14 +1304,21 @@ class QWenLMHeadModel(QWenPreTrainedModel):
                 device=inputs_tensor.device,
             )
         else:
-            input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")
+            input_ids = (
+                inputs_tensor
+                if model_input_name == "input_ids"
+                else model_kwargs.pop("input_ids")
+            )
 
         if streamer is not None:
             streamer.put(input_ids.cpu())
 
         # 6. Prepare `max_length` depending on other stopping criteria.
         input_ids_length = input_ids.shape[-1]
-        has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
+        has_default_max_length = (
+            kwargs.get("max_length") is None
+            and generation_config.max_length is not None
+        )
         if generation_config.max_new_tokens is not None:
             if not has_default_max_length and generation_config.max_length is not None:
                 logger.warning(
@@ -1225,8 +1327,12 @@ class QWenLMHeadModel(QWenPreTrainedModel):
                     "Please refer to the documentation for more information. "
                     "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
                 )
-            generation_config.max_length = generation_config.max_new_tokens + input_ids_length
-        self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
+            generation_config.max_length = (
+                generation_config.max_new_tokens + input_ids_length
+            )
+        self._validate_generated_length(
+            generation_config, input_ids_length, has_default_max_length
+        )
 
         # 7. determine generation mode
         generation_mode = self._get_generation_mode(generation_config, assistant_model)
@@ -1291,7 +1397,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
                 **model_kwargs,
             )
 
-
     def sample_base(
         self,
         input_ids: torch.LongTensor,
@@ -1309,10 +1414,15 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         streamer: Optional["BaseStreamer"] = None,
         **model_kwargs,
     ):
-
         # init values
-        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
-        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
+        logits_processor = (
+            logits_processor if logits_processor is not None else LogitsProcessorList()
+        )
+        stopping_criteria = (
+            stopping_criteria
+            if stopping_criteria is not None
+            else StoppingCriteriaList()
+        )
         # if max_length is not None:
         #     warnings.warn(
         #         "`max_length` is deprecated in this function, use"
@@ -1320,18 +1430,40 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         # UserWarning,
         # )
         # stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
-        logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
-        pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
-        eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
+        logits_warper = (
+            logits_warper if logits_warper is not None else LogitsProcessorList()
+        )
+        pad_token_id = (
+            pad_token_id
+            if pad_token_id is not None
+            else self.generation_config.pad_token_id
+        )
+        eos_token_id = (
+            eos_token_id
+            if eos_token_id is not None
+            else self.generation_config.eos_token_id
+        )
         if isinstance(eos_token_id, int):
             eos_token_id = [eos_token_id]
-        eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
-        output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
+        eos_token_id_tensor = (
+            torch.tensor(eos_token_id).to(input_ids.device)
+            if eos_token_id is not None
+            else None
+        )
+        output_scores = (
+            output_scores
+            if output_scores is not None
+            else self.generation_config.output_scores
+        )
         output_attentions = (
-            output_attentions if output_attentions is not None else self.generation_config.output_attentions
+            output_attentions
+            if output_attentions is not None
+            else self.generation_config.output_attentions
         )
         output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
+            output_hidden_states
+            if output_hidden_states is not None
+            else self.generation_config.output_hidden_states
         )
         return_dict_in_generate = (
             return_dict_in_generate
@@ -1341,19 +1473,33 @@ class QWenLMHeadModel(QWenPreTrainedModel):
 
         # init attention / hidden states / scores tuples
         scores = () if (return_dict_in_generate and output_scores) else None
-        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
-        cross_attentions = () if (return_dict_in_generate and output_attentions) else None
-        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
+        decoder_attentions = (
+            () if (return_dict_in_generate and output_attentions) else None
+        )
+        cross_attentions = (
+            () if (return_dict_in_generate and output_attentions) else None
+        )
+        decoder_hidden_states = (
+            () if (return_dict_in_generate and output_hidden_states) else None
+        )
 
         # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
         if return_dict_in_generate and self.config.is_encoder_decoder:
-            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
+            encoder_attentions = (
+                model_kwargs["encoder_outputs"].get("attentions")
+                if output_attentions
+                else None
+            )
             encoder_hidden_states = (
-                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
+                model_kwargs["encoder_outputs"].get("hidden_states")
+                if output_hidden_states
+                else None
             )
 
         # keep track of which sequences are already finished
-        unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
+        unfinished_sequences = torch.ones(
+            input_ids.shape[0], dtype=torch.long, device=input_ids.device
+        )
 
         this_peer_finished = False  # used by synced_gpus only
         # auto-regressive generation
@@ -1394,7 +1540,9 @@ class QWenLMHeadModel(QWenPreTrainedModel):
                     scores += (next_token_scores,)
                 if output_attentions:
                     decoder_attentions += (
-                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
+                        (outputs.decoder_attentions,)
+                        if self.config.is_encoder_decoder
+                        else (outputs.attentions,)
                     )
                     if self.config.is_encoder_decoder:
                         cross_attentions += (outputs.cross_attentions,)
@@ -1413,8 +1561,12 @@ class QWenLMHeadModel(QWenPreTrainedModel):
             # finished sentences should have their next token be a padding token
             if eos_token_id is not None:
                 if pad_token_id is None:
-                    raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
-                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
+                    raise ValueError(
+                        "If `eos_token_id` is defined, make sure that `pad_token_id` is defined."
+                    )
+                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (
+                    1 - unfinished_sequences
+                )
 
             # update generated ids, model inputs, and length for next step
             input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
@@ -1427,7 +1579,9 @@ class QWenLMHeadModel(QWenPreTrainedModel):
             # if eos_token was found in one sentence, set sentence to finished
             if eos_token_id_tensor is not None:
                 unfinished_sequences = unfinished_sequences.mul(
-                    next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
+                    next_tokens.tile(eos_token_id_tensor.shape[0], 1)
+                    .ne(eos_token_id_tensor.unsqueeze(1))
+                    .prod(dim=0)
                 )
 
             # stop when each sentence is finished
@@ -1446,7 +1600,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
 
         return input_ids
 
-
     # def backward(
     # self,
     # tokenizer,
@@ -1539,7 +1692,7 @@ def _rotate_half(x):
 
 
 def apply_rotary_pos_emb(t, freqs):
-    """ Apply rotary embedding to the first rotary_dim of the iput
+    """Apply rotary embedding to the first rotary_dim of the iput
 
     Arguments:
       t (tensor(batch_size, seq_len, n_head, head_dim)):