Refine the Qwen model.
parent 611396b656
commit 90cb0fe236
@@ -36,12 +36,6 @@ except ImportError:
 from torch import nn
 
 SUPPORT_CUDA = torch.cuda.is_available()
-SUPPORT_BF16 = SUPPORT_CUDA and torch.cuda.is_bf16_supported()
-SUPPORT_FP16 = SUPPORT_CUDA and torch.cuda.get_device_capability(0)[0] >= 7
-SUPPORT_TORCH2 = (
-    hasattr(torch, "__version__") and int(torch.__version__.split(".")[0]) >= 2
-)
-
 
 from configuration_qwen import QWenConfig
 from qwen_generation_utils import (
@@ -69,37 +63,6 @@ Pass argument `stream` to model.chat() is buggy, deprecated, and marked for remo
 向model.chat()传入参数stream的用法可能存在Bug,该用法已被废弃,将在未来被移除。请使用model.chat_stream(...)代替model.chat(..., stream=True)。
 """
 
-apply_rotary_emb_func = None
-rms_norm = None
-
-
-def quantize_cache_v(fdata, bits, qmax, qmin):
-    # b, s, head, h-dim->b, head, s, h-dim
-    qtype = torch.uint8
-    device = fdata.device
-    shape = fdata.shape
-
-    fdata_cal = torch.flatten(fdata, 2)
-    fmax = torch.amax(fdata_cal, dim=-1, keepdim=True)
-    fmin = torch.amin(fdata_cal, dim=-1, keepdim=True)
-    # Compute params
-    if qmax.device != fmax.device:
-        qmax = qmax.to(device)
-        qmin = qmin.to(device)
-    scale = (fmax - fmin) / (qmax - qmin)
-    zero = qmin - fmin / scale
-    scale = scale.unsqueeze(-1).repeat(1, 1, shape[2], 1).contiguous()
-    zero = zero.unsqueeze(-1).repeat(1, 1, shape[2], 1).contiguous()
-    # Quantize
-    res_data = fdata / scale + zero
-    qdata = torch.clamp(res_data, qmin, qmax).to(qtype)
-    return qdata.contiguous(), scale, zero
-
-
-def dequantize_cache_torch(qdata, scale, zero):
-    data = scale * (qdata - zero)
-    return data
-
 
 class QWenAttention(nn.Module):
     def __init__(self, config):
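For context on what is dropped here: quantize_cache_v and dequantize_cache_torch implement a per-(batch, head) affine uint8 round-trip over the KV cache. Below is a minimal standalone sketch of the same scheme, not the model code itself; the function names are illustrative and broadcasting stands in for the original .repeat calls.

import torch

def quantize_uint8(fdata, qmin=0.0, qmax=255.0):
    # Per (batch, head) min/max over the flattened trailing dims,
    # assuming fdata is laid out as (b, head, s, h-dim) as in the diff above.
    fdata_cal = torch.flatten(fdata, 2)
    fmax = torch.amax(fdata_cal, dim=-1, keepdim=True)
    fmin = torch.amin(fdata_cal, dim=-1, keepdim=True)
    scale = (fmax - fmin) / (qmax - qmin)
    zero = qmin - fmin / scale
    # Broadcast instead of the .repeat(...) calls in the removed helper.
    scale = scale.unsqueeze(-1)
    zero = zero.unsqueeze(-1)
    qdata = torch.clamp(fdata / scale + zero, qmin, qmax).to(torch.uint8)
    return qdata, scale, zero

def dequantize_uint8(qdata, scale, zero):
    # Inverse affine map, mirroring dequantize_cache_torch.
    return scale * (qdata - zero)

x = torch.randn(1, 2, 4, 8)            # b, head, s, h-dim
q, s, z = quantize_uint8(x)
x_hat = dequantize_uint8(q, s, z)
print((x - x_hat).abs().max())          # small reconstruction error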
@@ -128,9 +91,7 @@ class QWenAttention(nn.Module):
             config.hidden_size, self.projection_size, bias=not config.no_bias
         )
 
-        self.is_fp32 = not (config.bf16 or config.fp16)
-        self.bf16 = config.bf16
+        self.is_fp32 = True
 
         self.use_dynamic_ntk = config.use_dynamic_ntk
         self.use_logn_attn = config.use_logn_attn
@@ -146,128 +107,13 @@ class QWenAttention(nn.Module):
         self.softmax_in_fp32 = (
             config.softmax_in_fp32 if hasattr(config, "softmax_in_fp32") else False
         )
-        self.use_cache_quantization = (
-            config.use_cache_quantization
-            if hasattr(config, "use_cache_quantization")
-            else False
-        )
         self.use_cache_kernel = (
             config.use_cache_kernel if hasattr(config, "use_cache_kernel") else False
         )
         cache_dtype = torch.float
-        if self.bf16:
-            cache_dtype = torch.bfloat16
-        elif config.fp16:
-            cache_dtype = torch.float16
         self.cache_qmax = torch.tensor(torch.iinfo(torch.uint8).max, dtype=cache_dtype)
         self.cache_qmin = torch.tensor(torch.iinfo(torch.uint8).min, dtype=cache_dtype)
 
-        if config.use_cache_quantization and config.use_cache_kernel:
-            # pre check if the support files existing
-            module_root = pathlib.Path(__file__).parent
-            src_files = (
-                "cache_autogptq_cuda_256.cpp",
-                "cache_autogptq_cuda_kernel_256.cu",
-            )
-            if any(not (module_root / src).is_file() for src in src_files):
-                warnings.warn("KV cache kernel source files (.cpp and .cu) not found.")
-                self.cache_kernels = None
-            else:
-                try:
-                    from .cpp_kernels import cache_autogptq_cuda_256
-
-                    self.cache_kernels = cache_autogptq_cuda_256
-                except ImportError:
-                    warnings.warn("Failed to import KV cache kernels.")
-                    self.cache_kernels = None
-
-    def _attn(
-        self, query, key, value, causal_mask=None, attention_mask=None, head_mask=None
-    ):
-        device = query.device
-        if self.use_cache_quantization:
-            qk, qk_scale, qk_zero = key
-            if self.use_cache_kernel and self.cache_kernels is not None:
-                shape = query.shape[:-1] + (qk.shape[-2],)
-                attn_weights = torch.zeros(shape, dtype=torch.float16, device=device)
-                self.cache_kernels.vecquant8matmul_batched_faster_old(
-                    query.contiguous()
-                    if query.dtype == torch.float16
-                    else query.to(torch.float16).contiguous(),
-                    qk.transpose(-1, -2).contiguous(),
-                    attn_weights,
-                    qk_scale.contiguous()
-                    if qk_scale.dtype == torch.float16
-                    else qk_scale.to(torch.float16).contiguous(),
-                    qk_zero.contiguous()
-                    if qk_zero.dtype == torch.float16
-                    else qk_zero.to(torch.float16).contiguous(),
-                )
-                # attn_weights = attn_weights.to(query.dtype).contiguous()
-            else:
-                key = dequantize_cache_torch(qk, qk_scale, qk_zero)
-                attn_weights = torch.matmul(query, key.transpose(-1, -2))
-        else:
-            attn_weights = torch.matmul(query, key.transpose(-1, -2))
-
-        if self.scale_attn_weights:
-            if self.use_cache_quantization:
-                size_temp = value[0].size(-1)
-            else:
-                size_temp = value.size(-1)
-            attn_weights = attn_weights / (size_temp**0.5)
-
-        mask_value = torch.finfo(attn_weights.dtype).min
-        if causal_mask is not None:
-            attn_weights = torch.where(
-                causal_mask, attn_weights.to(attn_weights.dtype), mask_value
-            )
-
-        if attention_mask is not None:
-            attn_weights = attn_weights + attention_mask
-
-        if self.softmax_in_fp32:
-            attn_weights = nn.functional.softmax(attn_weights.float(), dim=-1)
-        else:
-            attn_weights = nn.functional.softmax(attn_weights, dim=-1)
-
-        attn_weights = attn_weights.type(query.dtype)
-        attn_weights = self.attn_dropout(attn_weights)
-
-        if head_mask is not None:
-            attn_weights = attn_weights * head_mask
-
-        if self.use_cache_quantization:
-            qv, qv_scale, qv_zero = value
-            if self.use_cache_kernel and self.cache_kernels is not None:
-                shape = attn_weights.shape[:-1] + (query.shape[-1],)
-                attn_output = torch.zeros(shape, dtype=torch.float16, device=device)
-                self.cache_kernels.vecquant8matmul_batched_column_compression_faster_old(
-                    attn_weights.contiguous()
-                    if attn_weights.dtype == torch.float16
-                    else attn_weights.to(torch.float16).contiguous(),
-                    qv.contiguous(),  # dtype: int32
-                    attn_output,
-                    qv_scale.contiguous()
-                    if qv_scale.dtype == torch.float16
-                    else qv_scale.to(torch.float16).contiguous(),
-                    qv_zero.contiguous()
-                    if qv_zero.dtype == torch.float16
-                    else qv_zero.to(torch.float16).contiguous(),
-                )
-                if attn_output.dtype != query.dtype:
-                    attn_output = attn_output.to(query.dtype)
-                    attn_weights = attn_weights.to(query.dtype)
-            else:
-                value = dequantize_cache_torch(qv, qv_scale, qv_zero)
-                attn_output = torch.matmul(attn_weights, value)
-        else:
-            attn_output = torch.matmul(attn_weights, value)
-
-        attn_output = attn_output.transpose(1, 2)
-
-        return attn_output, attn_weights
-
     def _split_heads(self, tensor, num_heads, attn_head_size):
         new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
         tensor = tensor.view(new_shape)
@@ -323,59 +169,25 @@ class QWenAttention(nn.Module):
                 query = torch.cat(query_list, dim=0)
                 key = torch.cat(key_list, dim=0)
 
-        if self.use_cache_quantization:
-            key = quantize_cache_v(
-                key.permute(0, 2, 1, 3),
-                bits=8,
-                qmin=self.cache_qmin,
-                qmax=self.cache_qmax,
-            )
-            value = quantize_cache_v(
-                value.permute(0, 2, 1, 3),
-                bits=8,
-                qmin=self.cache_qmin,
-                qmax=self.cache_qmax,
-            )
-
         if layer_past is not None:
             past_key, past_value = layer_past[0], layer_past[1]
-            if self.use_cache_quantization:
-                # use_cache_quantization:
-                # present=((q_key,key_scale,key_zero_point),
-                #          (q_value,value_scale,value_zero_point))
-                key = (
-                    torch.cat((past_key[0], key[0]), dim=2),
-                    torch.cat((past_key[1], key[1]), dim=2),
-                    torch.cat((past_key[2], key[2]), dim=2),
-                )
-                value = (
-                    torch.cat((past_value[0], value[0]), dim=2),
-                    torch.cat((past_value[1], value[1]), dim=2),
-                    torch.cat((past_value[2], value[2]), dim=2),
-                )
-            else:
-                # not use_cache_quantization:
-                # present=(key,value)
-                key = torch.cat((past_key, key), dim=1)
-                value = torch.cat((past_value, value), dim=1)
+            key = torch.cat((past_key, key), dim=1)
+            value = torch.cat((past_value, value), dim=1)
 
         if use_cache:
             present = (key, value)
         else:
             present = None
 
-        key_size = key[0].size(2) if self.use_cache_quantization else key.size(1)
+        key_size = key.size(1)
         if key_size > self.seq_length and self.use_logn_attn and not self.training:
-            if self.use_cache_quantization:
-                seq_start = key[0].size(2) - query.size(1)
-                seq_end = key[0].size(2)
-            else:
-                seq_start = key.size(1) - query.size(1)
-                seq_end = key.size(1)
+            seq_start = key.size(1) - query.size(1)
+            seq_end = key.size(1)
             logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :].type_as(query)
             query = query * logn_tensor.expand_as(query)
 
-        key_size = key[0].size(2) if self.use_cache_quantization else key.size(1)
+        key_size = key.size(1)
         if query.size(1) == key_size:
             causal_mask = torch.tril(
                 torch.ones((key_size, key_size), dtype=torch.bool, device=query.device)
@@ -383,39 +195,31 @@ class QWenAttention(nn.Module):
         else:
             causal_mask = None
         query = query.permute(0, 2, 1, 3)
-        if not self.use_cache_quantization:
-            key = key.permute(0, 2, 1, 3)
-            value = value.permute(0, 2, 1, 3)
+        key = key.permute(0, 2, 1, 3)
+        value = value.permute(0, 2, 1, 3)
 
-        if not self.use_cache_quantization and SUPPORT_TORCH2:
-            if attention_mask is not None:
-                attention_mask = attention_mask.expand(-1, -1, causal_mask.size(2), -1)
-                if causal_mask is not None:
-                    attention_mask = attention_mask.masked_fill(
-                        ~causal_mask, torch.finfo(query.dtype).min
-                    )
-            else:
-                attention_mask = causal_mask
-            attn_output = F.scaled_dot_product_attention(
-                query, key, value, attn_mask=attention_mask
-            ).transpose(1, 2)
-            attn_weight = None
+        if attention_mask is not None:
+            attention_mask = attention_mask.expand(-1, -1, causal_mask.size(2), -1)
+            if causal_mask is not None:
+                attention_mask = attention_mask.masked_fill(
+                    ~causal_mask, torch.finfo(query.dtype).min
+                )
         else:
-            attn_output, attn_weight = self._attn(
-                query, key, value, causal_mask, attention_mask, head_mask
-            )
+            attention_mask = causal_mask
+        attn_output = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask
+        ).transpose(1, 2)
+        attn_weight = None
 
         context_layer = self._merge_heads(attn_output, self.num_heads, self.head_dim)
 
         attn_output = self.c_proj(context_layer)
 
         outputs = (attn_output, present)
         if output_attentions:
-            if not self.use_cache_quantization and SUPPORT_TORCH2:
-                raise ValueError(
-                    "Cannot output attentions while using scaled_dot_product_attention"
-                )
-            else:
-                outputs += (attn_weight,)
+            raise ValueError(
+                "Cannot output attentions while using scaled_dot_product_attention"
+            )
 
         return outputs
 
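The path kept by this hunk builds a boolean causal mask, merges any additive attention mask into it, and hands everything to F.scaled_dot_product_attention. A small self-contained sketch with toy shapes, illustrative only:

import torch
import torch.nn.functional as F

b, h, s, d = 1, 2, 4, 8
query = torch.randn(b, h, s, d)
key = torch.randn(b, h, s, d)
value = torch.randn(b, h, s, d)

# Lower-triangular boolean causal mask, shaped for broadcasting over heads.
causal_mask = torch.tril(torch.ones((s, s), dtype=torch.bool)).view(1, 1, s, s)
# Additive attention mask (e.g. from padding); zeros mean "attend".
attention_mask = torch.zeros(b, 1, 1, s)
attention_mask = attention_mask.expand(-1, -1, causal_mask.size(2), -1)
attention_mask = attention_mask.masked_fill(~causal_mask, torch.finfo(query.dtype).min)

attn_output = F.scaled_dot_product_attention(
    query, key, value, attn_mask=attention_mask
).transpose(1, 2)
print(attn_output.shape)  # torch.Size([1, 4, 2, 8])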
@@ -444,7 +248,6 @@ class QWenBlock(nn.Module):
     def __init__(self, config):
         super().__init__()
         hidden_size = config.hidden_size
-        self.bf16 = config.bf16
 
         self.ln_1 = RMSNorm(
             hidden_size,
@@ -549,11 +352,6 @@ class QWenModel(QWenPreTrainedModel):
         self.vocab_size = config.vocab_size
         self.num_hidden_layers = config.num_hidden_layers
         self.embed_dim = config.hidden_size
-        self.use_cache_quantization = (
-            self.config.use_cache_quantization
-            if hasattr(self.config, "use_cache_quantization")
-            else False
-        )
 
         self.gradient_checkpointing = False
         self.use_dynamic_ntk = config.use_dynamic_ntk
@@ -571,7 +369,7 @@ class QWenModel(QWenPreTrainedModel):
         dim = self.rotary_ndims if self.rotary_ndims is not None else config.kv_channels
         self.rotary_emb = RotaryEmbedding(dim, base=config.rotary_emb_base)
 
-        self.is_fp32 = not (config.bf16 or config.fp16)
+        self.is_fp32 = True
 
         self.h = nn.ModuleList(
             [QWenBlock(config) for i in range(config.num_hidden_layers)]
@@ -651,10 +449,7 @@ class QWenModel(QWenPreTrainedModel):
             past_length = 0
             past_key_values = tuple([None] * len(self.h))
         else:
-            if self.use_cache_quantization:
-                past_length = past_key_values[0][0][0].size(2)
-            else:
-                past_length = past_key_values[0][0].size(-2)
+            past_length = past_key_values[0][0].size(-2)
         if position_ids is None:
             position_ids = torch.arange(
                 past_length,
@@ -682,10 +477,7 @@ class QWenModel(QWenPreTrainedModel):
         kv_seq_len = hidden_states.size()[1]
         if past_key_values[0] is not None:
             # past key values[0][0] shape: bs * seq_len * head_num * dim
-            if self.use_cache_quantization:
-                kv_seq_len += past_key_values[0][0][0].shape[2]
-            else:
-                kv_seq_len += past_key_values[0][0].shape[1]
+            kv_seq_len += past_key_values[0][0].shape[1]
 
         if self.training or not self.use_dynamic_ntk:
             ntk_alpha_list = [1.0]
@@ -796,55 +588,9 @@ class QWenLMHeadModel(QWenPreTrainedModel):
 
     def __init__(self, config):
         super().__init__(config)
-        assert (
-            config.bf16 + config.fp16 + config.fp32 <= 1
-        ), 'Only one of "bf16", "fp16", "fp32" can be true'
-
-        autoset_precision = config.bf16 + config.fp16 + config.fp32 == 0
-
-        if autoset_precision:
-            if SUPPORT_BF16:
-                logger.warn(
-                    "The model is automatically converting to bf16 for faster inference. "
-                    'If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to "AutoModelForCausalLM.from_pretrained".'
-                )
-                config.bf16 = True
-            elif SUPPORT_FP16:
-                logger.warn(
-                    "The model is automatically converting to fp16 for faster inference. "
-                    'If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to "AutoModelForCausalLM.from_pretrained".'
-                )
-                config.fp16 = True
-            else:
-                config.fp32 = True
-
-        if config.bf16 and SUPPORT_CUDA and not SUPPORT_BF16:
-            logger.warn(
-                'Your device does NOT seem to support bf16, you can switch to fp16 or fp32 by by passing fp16/fp32=True in "AutoModelForCausalLM.from_pretrained".'
-            )
-        if config.fp16 and SUPPORT_CUDA and not SUPPORT_FP16:
-            logger.warn(
-                "Your device does NOT support faster inference with fp16, please switch to fp32 which is likely to be faster"
-            )
-        if config.fp32:
-            if SUPPORT_BF16:
-                logger.warn(
-                    'Your device support faster inference by passing bf16=True in "AutoModelForCausalLM.from_pretrained".'
-                )
-            elif SUPPORT_FP16:
-                logger.warn(
-                    'Your device support faster inference by passing fp16=True in "AutoModelForCausalLM.from_pretrained".'
-                )
-
         self.transformer = QWenModel(config)
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
 
-        if config.bf16:
-            self.transformer.bfloat16()
-            self.lm_head.bfloat16()
-        if config.fp16:
-            self.transformer.half()
-            self.lm_head.half()
         self.post_init()
 
     def get_output_embeddings(self):
@@ -928,13 +674,13 @@ class QWenLMHeadModel(QWenPreTrainedModel):
                 shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
             )
 
-        # shift_labels = torch.ones([1,19]).to(lm_logits.device).to(torch.int64)
-        # shift_logits = lm_logits[..., :-1, :].contiguous()
-        # loss_fct = CrossEntropyLoss()
-        # loss = loss_fct(
-        #     shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
-        # )
-        # loss.backward()
+        shift_labels = torch.ones([1,19]).to(lm_logits.device).to(torch.int64)
+        shift_logits = lm_logits[..., :-1, :].contiguous()
+        loss_fct = CrossEntropyLoss()
+        loss = loss_fct(
+            shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
+        )
+        loss.backward()
 
         if not return_dict:
             output = (lm_logits,) + transformer_outputs[1:]
@@ -948,18 +694,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
             attentions=transformer_outputs.attentions,
         )
 
-    @staticmethod
-    def _reorder_cache(
-        past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
-    ) -> Tuple[Tuple[torch.Tensor]]:
-        return tuple(
-            tuple(
-                past_state.index_select(0, beam_idx.to(past_state.device))
-                for past_state in layer_past
-            )
-            for layer_past in past_key_values
-        )
-
     def chat(
         self,
         tokenizer: PreTrainedTokenizer,
@@ -1171,9 +905,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         negative_prompt_attention_mask: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> Union[GenerateOutput, torch.LongTensor]:
-        if synced_gpus is None:
-            synced_gpus = False
-
         # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
         self._validate_model_class()
 
@@ -1271,44 +1002,13 @@ class QWenLMHeadModel(QWenPreTrainedModel):
                 generation_config.eos_token_id,
             )
 
-        # decoder-only models should use left-padding for generation
-        if not self.config.is_encoder_decoder:
-            # If `input_ids` was given, check if the last id in any sequence is `pad_token_id`
-            # Note: If using, `inputs_embeds` this check does not work, because we want to be more hands-off.
-            if (
-                generation_config.pad_token_id is not None
-                and len(inputs_tensor.shape) == 2
-                and torch.sum(inputs_tensor[:, -1] == generation_config.pad_token_id)
-                > 0
-            ):
-                logger.warning(
-                    "A decoder-only architecture is being used, but right-padding was detected! For correct "
-                    "generation results, please set `padding_side='left'` when initializing the tokenizer."
-                )
-
-        if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs:
-            # if model is encoder decoder encoder_outputs are created
-            # and added to `model_kwargs`
-            model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
-                inputs_tensor, model_kwargs, model_input_name
-            )
-
         # 5. Prepare `input_ids` which will be used for auto-regressive generation
-        if self.config.is_encoder_decoder:
-            input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
-                batch_size=batch_size,
-                model_input_name=model_input_name,
-                model_kwargs=model_kwargs,
-                decoder_start_token_id=generation_config.decoder_start_token_id,
-                bos_token_id=generation_config.bos_token_id,
-                device=inputs_tensor.device,
-            )
-        else:
-            input_ids = (
-                inputs_tensor
-                if model_input_name == "input_ids"
-                else model_kwargs.pop("input_ids")
-            )
+        input_ids = (
+            inputs_tensor
+            if model_input_name == "input_ids"
+            else model_kwargs.pop("input_ids")
+        )
 
         if streamer is not None:
             streamer.put(input_ids.cpu())
@@ -1378,7 +1078,7 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         input_ids, model_kwargs = self._expand_inputs_for_generation(
             input_ids=input_ids,
             expand_size=generation_config.num_return_sequences,
-            is_encoder_decoder=self.config.is_encoder_decoder,
+            is_encoder_decoder=False,
             **model_kwargs,
         )
 
@@ -1483,19 +1183,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
             () if (return_dict_in_generate and output_hidden_states) else None
         )
 
-        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
-        if return_dict_in_generate and self.config.is_encoder_decoder:
-            encoder_attentions = (
-                model_kwargs["encoder_outputs"].get("attentions")
-                if output_attentions
-                else None
-            )
-            encoder_hidden_states = (
-                model_kwargs["encoder_outputs"].get("hidden_states")
-                if output_hidden_states
-                else None
-            )
-
         # keep track of which sequences are already finished
         unfinished_sequences = torch.ones(
             input_ids.shape[0], dtype=torch.long, device=input_ids.device
@@ -1504,16 +1191,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         this_peer_finished = False  # used by synced_gpus only
         # auto-regressive generation
         while True:
-            # if synced_gpus:
-            #     # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
-            #     # The following logic allows an early break if all peers finished generating their sequence
-            #     this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
-            #     # send 0.0 if we finished, 1.0 otherwise
-            #     dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
-            #     # did all peers finish? the reduced sum will be 0.0 then
-            #     if this_peer_finished_flag.item() == 0.0:
-            #         break
-
             # prepare model inputs
             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
 
@@ -1525,9 +1202,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
                 output_hidden_states=output_hidden_states,
             )
 
-            if synced_gpus and this_peer_finished:
-                continue  # don't waste resources running the code we don't need
-
             next_token_logits = outputs.logits[:, -1, :]
 
             # pre-process distribution
@@ -1539,20 +1213,10 @@ class QWenLMHeadModel(QWenPreTrainedModel):
                 if output_scores:
                     scores += (next_token_scores,)
                 if output_attentions:
-                    decoder_attentions += (
-                        (outputs.decoder_attentions,)
-                        if self.config.is_encoder_decoder
-                        else (outputs.attentions,)
-                    )
-                    if self.config.is_encoder_decoder:
-                        cross_attentions += (outputs.cross_attentions,)
+                    decoder_attentions += (outputs.attentions,)
 
                 if output_hidden_states:
-                    decoder_hidden_states += (
-                        (outputs.decoder_hidden_states,)
-                        if self.config.is_encoder_decoder
-                        else (outputs.hidden_states,)
-                    )
+                    decoder_hidden_states += (outputs.hidden_states,)
 
             # sample
             probs = nn.functional.softmax(next_token_scores, dim=-1)
@@ -1573,7 +1237,7 @@ class QWenLMHeadModel(QWenPreTrainedModel):
             if streamer is not None:
                 streamer.put(next_tokens.cpu())
             model_kwargs = self._update_model_kwargs_for_generation(
-                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
+                outputs, model_kwargs, is_encoder_decoder=False
             )
 
             # if eos_token was found in one sentence, set sentence to finished
@@ -1592,7 +1256,7 @@ class QWenLMHeadModel(QWenPreTrainedModel):
             if stopping_criteria(input_ids, scores):
                 this_peer_finished = True
 
-            if this_peer_finished and not synced_gpus:
+            if this_peer_finished:
                 break
 
         if streamer is not None:
@@ -1600,44 +1264,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
 
         return input_ids
 
-    # def backward(
-    #     self,
-    #     tokenizer,
-    #     query: str,
-    # ):
-    #     inputs = tokenizer.build_chat_input(query, history=[], role="user")
-    #     inputs = inputs.to(next(self.parameters()).device)
-
-    #     generation_config = copy.deepcopy(self.generation_config)
-    #     inputs_tensor = inputs["input_ids"]
-    #     input_ids = inputs_tensor.repeat_interleave(
-    #         generation_config.num_return_sequences, dim=0
-    #     )
-
-    #     input_ids_in = input_ids
-    #     batch_size, seq_length = input_ids_in.shape
-    #     position_ids_in = (
-    #         torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
-    #         .unsqueeze(0)
-    #         .repeat(batch_size, 1)
-    #     )
-    #     model_inputs = {"input_ids": input_ids_in, "position_ids": position_ids_in}
-
-    #     probs, next_tokens = self.transformer(
-    #         **model_inputs,
-    #         output_hidden_states=None,
-    #         tokenizer=tokenizer,
-    #     )
-
-    #     next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
-    #     # probs_target = probs
-    #     # probs_target[0, next_tokens] = probs_target[0, next_tokens] * 1.1
-
-    #     loss = probs[0, next_tokens]
-    #     loss.backward()
-
-    #     return loss
-
 
 class RotaryEmbedding(torch.nn.Module):
     def __init__(self, dim, base=10000):
@@ -1703,17 +1329,9 @@ def apply_rotary_pos_emb(t, freqs):
     rot_dim = freqs[0].shape[-1]
     cos, sin = freqs
     t_float = t.float()
-    if apply_rotary_emb_func is not None and t.is_cuda:
-        # apply_rotary_emb in flash_attn requires cos/sin to be of
-        # shape (seqlen, rotary_dim / 2) and apply rotary embedding
-        # to the first rotary_dim of the input
-        cos = cos.squeeze(0).squeeze(1)[:, : rot_dim // 2]
-        sin = sin.squeeze(0).squeeze(1)[:, : rot_dim // 2]
-        return apply_rotary_emb_func(t_float, cos, sin).type_as(t)
-    else:
-        t_rot, t_pass = t_float[..., :rot_dim], t_float[..., rot_dim:]
-        t_rot = (t_rot * cos) + (_rotate_half(t_rot) * sin)
-        return torch.cat((t_rot, t_pass), dim=-1).type_as(t)
+    t_rot, t_pass = t_float[..., :rot_dim], t_float[..., rot_dim:]
+    t_rot = (t_rot * cos) + (_rotate_half(t_rot) * sin)
+    return torch.cat((t_rot, t_pass), dim=-1).type_as(t)
 
 
 class RMSNorm(torch.nn.Module):
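The branch that survives in apply_rotary_pos_emb is the plain PyTorch rotary application. A runnable sketch under the usual assumptions (the _rotate_half definition below is the conventional split-and-negate form, assumed here since it is not part of this diff):

import torch

def _rotate_half(x):
    # Split the last dim in half and rotate: (x1, x2) -> (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary(t, cos, sin, rot_dim):
    t_float = t.float()
    t_rot, t_pass = t_float[..., :rot_dim], t_float[..., rot_dim:]
    t_rot = (t_rot * cos) + (_rotate_half(t_rot) * sin)
    return torch.cat((t_rot, t_pass), dim=-1).type_as(t)

# Toy shapes: (batch, seq, heads, head_dim) with rot_dim == head_dim.
seq, rot_dim = 5, 8
inv_freq = 1.0 / (10000 ** (torch.arange(0, rot_dim, 2).float() / rot_dim))
freqs = torch.outer(torch.arange(seq).float(), inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
cos, sin = emb.cos()[None, :, None, :], emb.sin()[None, :, None, :]
t = torch.randn(1, seq, 2, rot_dim)
print(apply_rotary(t, cos, sin, rot_dim).shape)  # torch.Size([1, 5, 2, 8])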
@@ -1726,8 +1344,5 @@ class RMSNorm(torch.nn.Module):
         return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
 
     def forward(self, x):
-        if rms_norm is not None and x.is_cuda:
-            return rms_norm(x, self.weight, self.eps)
-        else:
-            output = self._norm(x.float()).type_as(x)
-            return output * self.weight
+        output = self._norm(x.float()).type_as(x)
+        return output * self.weight