Refine model of qwen.
This commit is contained in:
parent
611396b656
commit
90cb0fe236
|
@ -36,12 +36,6 @@ except ImportError:
|
|||
from torch import nn
|
||||
|
||||
SUPPORT_CUDA = torch.cuda.is_available()
|
||||
SUPPORT_BF16 = SUPPORT_CUDA and torch.cuda.is_bf16_supported()
|
||||
SUPPORT_FP16 = SUPPORT_CUDA and torch.cuda.get_device_capability(0)[0] >= 7
|
||||
SUPPORT_TORCH2 = (
|
||||
hasattr(torch, "__version__") and int(torch.__version__.split(".")[0]) >= 2
|
||||
)
|
||||
|
||||
|
||||
from configuration_qwen import QWenConfig
|
||||
from qwen_generation_utils import (
|
||||
|
@ -69,37 +63,6 @@ Pass argument `stream` to model.chat() is buggy, deprecated, and marked for remo
|
|||
向model.chat()传入参数stream的用法可能存在Bug,该用法已被废弃,将在未来被移除。请使用model.chat_stream(...)代替model.chat(..., stream=True)。
|
||||
"""
|
||||
|
||||
apply_rotary_emb_func = None
|
||||
rms_norm = None
|
||||
|
||||
|
||||
def quantize_cache_v(fdata, bits, qmax, qmin):
|
||||
# b, s, head, h-dim->b, head, s, h-dim
|
||||
qtype = torch.uint8
|
||||
device = fdata.device
|
||||
shape = fdata.shape
|
||||
|
||||
fdata_cal = torch.flatten(fdata, 2)
|
||||
fmax = torch.amax(fdata_cal, dim=-1, keepdim=True)
|
||||
fmin = torch.amin(fdata_cal, dim=-1, keepdim=True)
|
||||
# Compute params
|
||||
if qmax.device != fmax.device:
|
||||
qmax = qmax.to(device)
|
||||
qmin = qmin.to(device)
|
||||
scale = (fmax - fmin) / (qmax - qmin)
|
||||
zero = qmin - fmin / scale
|
||||
scale = scale.unsqueeze(-1).repeat(1, 1, shape[2], 1).contiguous()
|
||||
zero = zero.unsqueeze(-1).repeat(1, 1, shape[2], 1).contiguous()
|
||||
# Quantize
|
||||
res_data = fdata / scale + zero
|
||||
qdata = torch.clamp(res_data, qmin, qmax).to(qtype)
|
||||
return qdata.contiguous(), scale, zero
|
||||
|
||||
|
||||
def dequantize_cache_torch(qdata, scale, zero):
|
||||
data = scale * (qdata - zero)
|
||||
return data
|
||||
|
||||
|
||||
class QWenAttention(nn.Module):
|
||||
def __init__(self, config):
|
||||
|
@ -128,9 +91,7 @@ class QWenAttention(nn.Module):
|
|||
config.hidden_size, self.projection_size, bias=not config.no_bias
|
||||
)
|
||||
|
||||
self.is_fp32 = not (config.bf16 or config.fp16)
|
||||
|
||||
self.bf16 = config.bf16
|
||||
self.is_fp32 = True
|
||||
|
||||
self.use_dynamic_ntk = config.use_dynamic_ntk
|
||||
self.use_logn_attn = config.use_logn_attn
|
||||
|
@ -146,128 +107,13 @@ class QWenAttention(nn.Module):
|
|||
self.softmax_in_fp32 = (
|
||||
config.softmax_in_fp32 if hasattr(config, "softmax_in_fp32") else False
|
||||
)
|
||||
self.use_cache_quantization = (
|
||||
config.use_cache_quantization
|
||||
if hasattr(config, "use_cache_quantization")
|
||||
else False
|
||||
)
|
||||
self.use_cache_kernel = (
|
||||
config.use_cache_kernel if hasattr(config, "use_cache_kernel") else False
|
||||
)
|
||||
cache_dtype = torch.float
|
||||
if self.bf16:
|
||||
cache_dtype = torch.bfloat16
|
||||
elif config.fp16:
|
||||
cache_dtype = torch.float16
|
||||
self.cache_qmax = torch.tensor(torch.iinfo(torch.uint8).max, dtype=cache_dtype)
|
||||
self.cache_qmin = torch.tensor(torch.iinfo(torch.uint8).min, dtype=cache_dtype)
|
||||
|
||||
if config.use_cache_quantization and config.use_cache_kernel:
|
||||
# pre check if the support files existing
|
||||
module_root = pathlib.Path(__file__).parent
|
||||
src_files = (
|
||||
"cache_autogptq_cuda_256.cpp",
|
||||
"cache_autogptq_cuda_kernel_256.cu",
|
||||
)
|
||||
if any(not (module_root / src).is_file() for src in src_files):
|
||||
warnings.warn("KV cache kernel source files (.cpp and .cu) not found.")
|
||||
self.cache_kernels = None
|
||||
else:
|
||||
try:
|
||||
from .cpp_kernels import cache_autogptq_cuda_256
|
||||
|
||||
self.cache_kernels = cache_autogptq_cuda_256
|
||||
except ImportError:
|
||||
warnings.warn("Failed to import KV cache kernels.")
|
||||
self.cache_kernels = None
|
||||
|
||||
def _attn(
|
||||
self, query, key, value, causal_mask=None, attention_mask=None, head_mask=None
|
||||
):
|
||||
device = query.device
|
||||
if self.use_cache_quantization:
|
||||
qk, qk_scale, qk_zero = key
|
||||
if self.use_cache_kernel and self.cache_kernels is not None:
|
||||
shape = query.shape[:-1] + (qk.shape[-2],)
|
||||
attn_weights = torch.zeros(shape, dtype=torch.float16, device=device)
|
||||
self.cache_kernels.vecquant8matmul_batched_faster_old(
|
||||
query.contiguous()
|
||||
if query.dtype == torch.float16
|
||||
else query.to(torch.float16).contiguous(),
|
||||
qk.transpose(-1, -2).contiguous(),
|
||||
attn_weights,
|
||||
qk_scale.contiguous()
|
||||
if qk_scale.dtype == torch.float16
|
||||
else qk_scale.to(torch.float16).contiguous(),
|
||||
qk_zero.contiguous()
|
||||
if qk_zero.dtype == torch.float16
|
||||
else qk_zero.to(torch.float16).contiguous(),
|
||||
)
|
||||
# attn_weights = attn_weights.to(query.dtype).contiguous()
|
||||
else:
|
||||
key = dequantize_cache_torch(qk, qk_scale, qk_zero)
|
||||
attn_weights = torch.matmul(query, key.transpose(-1, -2))
|
||||
else:
|
||||
attn_weights = torch.matmul(query, key.transpose(-1, -2))
|
||||
|
||||
if self.scale_attn_weights:
|
||||
if self.use_cache_quantization:
|
||||
size_temp = value[0].size(-1)
|
||||
else:
|
||||
size_temp = value.size(-1)
|
||||
attn_weights = attn_weights / (size_temp**0.5)
|
||||
|
||||
mask_value = torch.finfo(attn_weights.dtype).min
|
||||
if causal_mask is not None:
|
||||
attn_weights = torch.where(
|
||||
causal_mask, attn_weights.to(attn_weights.dtype), mask_value
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
if self.softmax_in_fp32:
|
||||
attn_weights = nn.functional.softmax(attn_weights.float(), dim=-1)
|
||||
else:
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||
|
||||
attn_weights = attn_weights.type(query.dtype)
|
||||
attn_weights = self.attn_dropout(attn_weights)
|
||||
|
||||
if head_mask is not None:
|
||||
attn_weights = attn_weights * head_mask
|
||||
|
||||
if self.use_cache_quantization:
|
||||
qv, qv_scale, qv_zero = value
|
||||
if self.use_cache_kernel and self.cache_kernels is not None:
|
||||
shape = attn_weights.shape[:-1] + (query.shape[-1],)
|
||||
attn_output = torch.zeros(shape, dtype=torch.float16, device=device)
|
||||
self.cache_kernels.vecquant8matmul_batched_column_compression_faster_old(
|
||||
attn_weights.contiguous()
|
||||
if attn_weights.dtype == torch.float16
|
||||
else attn_weights.to(torch.float16).contiguous(),
|
||||
qv.contiguous(), # dtype: int32
|
||||
attn_output,
|
||||
qv_scale.contiguous()
|
||||
if qv_scale.dtype == torch.float16
|
||||
else qv_scale.to(torch.float16).contiguous(),
|
||||
qv_zero.contiguous()
|
||||
if qv_zero.dtype == torch.float16
|
||||
else qv_zero.to(torch.float16).contiguous(),
|
||||
)
|
||||
if attn_output.dtype != query.dtype:
|
||||
attn_output = attn_output.to(query.dtype)
|
||||
attn_weights = attn_weights.to(query.dtype)
|
||||
else:
|
||||
value = dequantize_cache_torch(qv, qv_scale, qv_zero)
|
||||
attn_output = torch.matmul(attn_weights, value)
|
||||
else:
|
||||
attn_output = torch.matmul(attn_weights, value)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2)
|
||||
|
||||
return attn_output, attn_weights
|
||||
|
||||
def _split_heads(self, tensor, num_heads, attn_head_size):
|
||||
new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
|
||||
tensor = tensor.view(new_shape)
|
||||
|
@ -323,39 +169,9 @@ class QWenAttention(nn.Module):
|
|||
query = torch.cat(query_list, dim=0)
|
||||
key = torch.cat(key_list, dim=0)
|
||||
|
||||
if self.use_cache_quantization:
|
||||
key = quantize_cache_v(
|
||||
key.permute(0, 2, 1, 3),
|
||||
bits=8,
|
||||
qmin=self.cache_qmin,
|
||||
qmax=self.cache_qmax,
|
||||
)
|
||||
value = quantize_cache_v(
|
||||
value.permute(0, 2, 1, 3),
|
||||
bits=8,
|
||||
qmin=self.cache_qmin,
|
||||
qmax=self.cache_qmax,
|
||||
)
|
||||
|
||||
if layer_past is not None:
|
||||
past_key, past_value = layer_past[0], layer_past[1]
|
||||
if self.use_cache_quantization:
|
||||
# use_cache_quantization:
|
||||
# present=((q_key,key_scale,key_zero_point),
|
||||
# (q_value,value_scale,value_zero_point))
|
||||
key = (
|
||||
torch.cat((past_key[0], key[0]), dim=2),
|
||||
torch.cat((past_key[1], key[1]), dim=2),
|
||||
torch.cat((past_key[2], key[2]), dim=2),
|
||||
)
|
||||
value = (
|
||||
torch.cat((past_value[0], value[0]), dim=2),
|
||||
torch.cat((past_value[1], value[1]), dim=2),
|
||||
torch.cat((past_value[2], value[2]), dim=2),
|
||||
)
|
||||
else:
|
||||
# not use_cache_quantization:
|
||||
# present=(key,value)
|
||||
|
||||
key = torch.cat((past_key, key), dim=1)
|
||||
value = torch.cat((past_value, value), dim=1)
|
||||
|
||||
|
@ -364,18 +180,14 @@ class QWenAttention(nn.Module):
|
|||
else:
|
||||
present = None
|
||||
|
||||
key_size = key[0].size(2) if self.use_cache_quantization else key.size(1)
|
||||
key_size = key.size(1)
|
||||
if key_size > self.seq_length and self.use_logn_attn and not self.training:
|
||||
if self.use_cache_quantization:
|
||||
seq_start = key[0].size(2) - query.size(1)
|
||||
seq_end = key[0].size(2)
|
||||
else:
|
||||
seq_start = key.size(1) - query.size(1)
|
||||
seq_end = key.size(1)
|
||||
logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :].type_as(query)
|
||||
query = query * logn_tensor.expand_as(query)
|
||||
|
||||
key_size = key[0].size(2) if self.use_cache_quantization else key.size(1)
|
||||
key_size = key.size(1)
|
||||
if query.size(1) == key_size:
|
||||
causal_mask = torch.tril(
|
||||
torch.ones((key_size, key_size), dtype=torch.bool, device=query.device)
|
||||
|
@ -383,11 +195,9 @@ class QWenAttention(nn.Module):
|
|||
else:
|
||||
causal_mask = None
|
||||
query = query.permute(0, 2, 1, 3)
|
||||
if not self.use_cache_quantization:
|
||||
key = key.permute(0, 2, 1, 3)
|
||||
value = value.permute(0, 2, 1, 3)
|
||||
|
||||
if not self.use_cache_quantization and SUPPORT_TORCH2:
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask.expand(-1, -1, causal_mask.size(2), -1)
|
||||
if causal_mask is not None:
|
||||
|
@ -400,22 +210,16 @@ class QWenAttention(nn.Module):
|
|||
query, key, value, attn_mask=attention_mask
|
||||
).transpose(1, 2)
|
||||
attn_weight = None
|
||||
else:
|
||||
attn_output, attn_weight = self._attn(
|
||||
query, key, value, causal_mask, attention_mask, head_mask
|
||||
)
|
||||
|
||||
context_layer = self._merge_heads(attn_output, self.num_heads, self.head_dim)
|
||||
|
||||
attn_output = self.c_proj(context_layer)
|
||||
|
||||
outputs = (attn_output, present)
|
||||
if output_attentions:
|
||||
if not self.use_cache_quantization and SUPPORT_TORCH2:
|
||||
raise ValueError(
|
||||
"Cannot output attentions while using scaled_dot_product_attention"
|
||||
)
|
||||
else:
|
||||
outputs += (attn_weight,)
|
||||
|
||||
return outputs
|
||||
|
||||
|
@ -444,7 +248,6 @@ class QWenBlock(nn.Module):
|
|||
def __init__(self, config):
|
||||
super().__init__()
|
||||
hidden_size = config.hidden_size
|
||||
self.bf16 = config.bf16
|
||||
|
||||
self.ln_1 = RMSNorm(
|
||||
hidden_size,
|
||||
|
@ -549,11 +352,6 @@ class QWenModel(QWenPreTrainedModel):
|
|||
self.vocab_size = config.vocab_size
|
||||
self.num_hidden_layers = config.num_hidden_layers
|
||||
self.embed_dim = config.hidden_size
|
||||
self.use_cache_quantization = (
|
||||
self.config.use_cache_quantization
|
||||
if hasattr(self.config, "use_cache_quantization")
|
||||
else False
|
||||
)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
self.use_dynamic_ntk = config.use_dynamic_ntk
|
||||
|
@ -571,7 +369,7 @@ class QWenModel(QWenPreTrainedModel):
|
|||
dim = self.rotary_ndims if self.rotary_ndims is not None else config.kv_channels
|
||||
self.rotary_emb = RotaryEmbedding(dim, base=config.rotary_emb_base)
|
||||
|
||||
self.is_fp32 = not (config.bf16 or config.fp16)
|
||||
self.is_fp32 = True
|
||||
|
||||
self.h = nn.ModuleList(
|
||||
[QWenBlock(config) for i in range(config.num_hidden_layers)]
|
||||
|
@ -650,9 +448,6 @@ class QWenModel(QWenPreTrainedModel):
|
|||
if past_key_values is None:
|
||||
past_length = 0
|
||||
past_key_values = tuple([None] * len(self.h))
|
||||
else:
|
||||
if self.use_cache_quantization:
|
||||
past_length = past_key_values[0][0][0].size(2)
|
||||
else:
|
||||
past_length = past_key_values[0][0].size(-2)
|
||||
if position_ids is None:
|
||||
|
@ -682,9 +477,6 @@ class QWenModel(QWenPreTrainedModel):
|
|||
kv_seq_len = hidden_states.size()[1]
|
||||
if past_key_values[0] is not None:
|
||||
# past key values[0][0] shape: bs * seq_len * head_num * dim
|
||||
if self.use_cache_quantization:
|
||||
kv_seq_len += past_key_values[0][0][0].shape[2]
|
||||
else:
|
||||
kv_seq_len += past_key_values[0][0].shape[1]
|
||||
|
||||
if self.training or not self.use_dynamic_ntk:
|
||||
|
@ -796,55 +588,9 @@ class QWenLMHeadModel(QWenPreTrainedModel):
|
|||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
assert (
|
||||
config.bf16 + config.fp16 + config.fp32 <= 1
|
||||
), 'Only one of "bf16", "fp16", "fp32" can be true'
|
||||
|
||||
autoset_precision = config.bf16 + config.fp16 + config.fp32 == 0
|
||||
|
||||
if autoset_precision:
|
||||
if SUPPORT_BF16:
|
||||
logger.warn(
|
||||
"The model is automatically converting to bf16 for faster inference. "
|
||||
'If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to "AutoModelForCausalLM.from_pretrained".'
|
||||
)
|
||||
config.bf16 = True
|
||||
elif SUPPORT_FP16:
|
||||
logger.warn(
|
||||
"The model is automatically converting to fp16 for faster inference. "
|
||||
'If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to "AutoModelForCausalLM.from_pretrained".'
|
||||
)
|
||||
config.fp16 = True
|
||||
else:
|
||||
config.fp32 = True
|
||||
|
||||
if config.bf16 and SUPPORT_CUDA and not SUPPORT_BF16:
|
||||
logger.warn(
|
||||
'Your device does NOT seem to support bf16, you can switch to fp16 or fp32 by by passing fp16/fp32=True in "AutoModelForCausalLM.from_pretrained".'
|
||||
)
|
||||
if config.fp16 and SUPPORT_CUDA and not SUPPORT_FP16:
|
||||
logger.warn(
|
||||
"Your device does NOT support faster inference with fp16, please switch to fp32 which is likely to be faster"
|
||||
)
|
||||
if config.fp32:
|
||||
if SUPPORT_BF16:
|
||||
logger.warn(
|
||||
'Your device support faster inference by passing bf16=True in "AutoModelForCausalLM.from_pretrained".'
|
||||
)
|
||||
elif SUPPORT_FP16:
|
||||
logger.warn(
|
||||
'Your device support faster inference by passing fp16=True in "AutoModelForCausalLM.from_pretrained".'
|
||||
)
|
||||
|
||||
self.transformer = QWenModel(config)
|
||||
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
|
||||
if config.bf16:
|
||||
self.transformer.bfloat16()
|
||||
self.lm_head.bfloat16()
|
||||
if config.fp16:
|
||||
self.transformer.half()
|
||||
self.lm_head.half()
|
||||
self.post_init()
|
||||
|
||||
def get_output_embeddings(self):
|
||||
|
@ -928,13 +674,13 @@ class QWenLMHeadModel(QWenPreTrainedModel):
|
|||
shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
|
||||
)
|
||||
|
||||
# shift_labels = torch.ones([1,19]).to(lm_logits.device).to(torch.int64)
|
||||
# shift_logits = lm_logits[..., :-1, :].contiguous()
|
||||
# loss_fct = CrossEntropyLoss()
|
||||
# loss = loss_fct(
|
||||
# shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
|
||||
# )
|
||||
# loss.backward()
|
||||
shift_labels = torch.ones([1,19]).to(lm_logits.device).to(torch.int64)
|
||||
shift_logits = lm_logits[..., :-1, :].contiguous()
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(
|
||||
shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
|
||||
)
|
||||
loss.backward()
|
||||
|
||||
if not return_dict:
|
||||
output = (lm_logits,) + transformer_outputs[1:]
|
||||
|
@ -948,18 +694,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
|
|||
attentions=transformer_outputs.attentions,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _reorder_cache(
|
||||
past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
|
||||
) -> Tuple[Tuple[torch.Tensor]]:
|
||||
return tuple(
|
||||
tuple(
|
||||
past_state.index_select(0, beam_idx.to(past_state.device))
|
||||
for past_state in layer_past
|
||||
)
|
||||
for layer_past in past_key_values
|
||||
)
|
||||
|
||||
def chat(
|
||||
self,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
|
@ -1171,9 +905,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
|
|||
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> Union[GenerateOutput, torch.LongTensor]:
|
||||
if synced_gpus is None:
|
||||
synced_gpus = False
|
||||
|
||||
# 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
|
||||
self._validate_model_class()
|
||||
|
||||
|
@ -1271,39 +1002,8 @@ class QWenLMHeadModel(QWenPreTrainedModel):
|
|||
generation_config.eos_token_id,
|
||||
)
|
||||
|
||||
# decoder-only models should use left-padding for generation
|
||||
if not self.config.is_encoder_decoder:
|
||||
# If `input_ids` was given, check if the last id in any sequence is `pad_token_id`
|
||||
# Note: If using, `inputs_embeds` this check does not work, because we want to be more hands-off.
|
||||
if (
|
||||
generation_config.pad_token_id is not None
|
||||
and len(inputs_tensor.shape) == 2
|
||||
and torch.sum(inputs_tensor[:, -1] == generation_config.pad_token_id)
|
||||
> 0
|
||||
):
|
||||
logger.warning(
|
||||
"A decoder-only architecture is being used, but right-padding was detected! For correct "
|
||||
"generation results, please set `padding_side='left'` when initializing the tokenizer."
|
||||
)
|
||||
|
||||
if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs:
|
||||
# if model is encoder decoder encoder_outputs are created
|
||||
# and added to `model_kwargs`
|
||||
model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
|
||||
inputs_tensor, model_kwargs, model_input_name
|
||||
)
|
||||
|
||||
# 5. Prepare `input_ids` which will be used for auto-regressive generation
|
||||
if self.config.is_encoder_decoder:
|
||||
input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
|
||||
batch_size=batch_size,
|
||||
model_input_name=model_input_name,
|
||||
model_kwargs=model_kwargs,
|
||||
decoder_start_token_id=generation_config.decoder_start_token_id,
|
||||
bos_token_id=generation_config.bos_token_id,
|
||||
device=inputs_tensor.device,
|
||||
)
|
||||
else:
|
||||
|
||||
input_ids = (
|
||||
inputs_tensor
|
||||
if model_input_name == "input_ids"
|
||||
|
@ -1378,7 +1078,7 @@ class QWenLMHeadModel(QWenPreTrainedModel):
|
|||
input_ids, model_kwargs = self._expand_inputs_for_generation(
|
||||
input_ids=input_ids,
|
||||
expand_size=generation_config.num_return_sequences,
|
||||
is_encoder_decoder=self.config.is_encoder_decoder,
|
||||
is_encoder_decoder=False,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
|
@ -1483,19 +1183,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
|
|||
() if (return_dict_in_generate and output_hidden_states) else None
|
||||
)
|
||||
|
||||
# if model is an encoder-decoder, retrieve encoder attention weights and hidden states
|
||||
if return_dict_in_generate and self.config.is_encoder_decoder:
|
||||
encoder_attentions = (
|
||||
model_kwargs["encoder_outputs"].get("attentions")
|
||||
if output_attentions
|
||||
else None
|
||||
)
|
||||
encoder_hidden_states = (
|
||||
model_kwargs["encoder_outputs"].get("hidden_states")
|
||||
if output_hidden_states
|
||||
else None
|
||||
)
|
||||
|
||||
# keep track of which sequences are already finished
|
||||
unfinished_sequences = torch.ones(
|
||||
input_ids.shape[0], dtype=torch.long, device=input_ids.device
|
||||
|
@ -1504,16 +1191,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
|
|||
this_peer_finished = False # used by synced_gpus only
|
||||
# auto-regressive generation
|
||||
while True:
|
||||
# if synced_gpus:
|
||||
# # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
|
||||
# # The following logic allows an early break if all peers finished generating their sequence
|
||||
# this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
|
||||
# # send 0.0 if we finished, 1.0 otherwise
|
||||
# dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
|
||||
# # did all peers finish? the reduced sum will be 0.0 then
|
||||
# if this_peer_finished_flag.item() == 0.0:
|
||||
# break
|
||||
|
||||
# prepare model inputs
|
||||
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
|
||||
|
||||
|
@ -1525,9 +1202,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
|
|||
output_hidden_states=output_hidden_states,
|
||||
)
|
||||
|
||||
if synced_gpus and this_peer_finished:
|
||||
continue # don't waste resources running the code we don't need
|
||||
|
||||
next_token_logits = outputs.logits[:, -1, :]
|
||||
|
||||
# pre-process distribution
|
||||
|
@ -1539,20 +1213,10 @@ class QWenLMHeadModel(QWenPreTrainedModel):
|
|||
if output_scores:
|
||||
scores += (next_token_scores,)
|
||||
if output_attentions:
|
||||
decoder_attentions += (
|
||||
(outputs.decoder_attentions,)
|
||||
if self.config.is_encoder_decoder
|
||||
else (outputs.attentions,)
|
||||
)
|
||||
if self.config.is_encoder_decoder:
|
||||
cross_attentions += (outputs.cross_attentions,)
|
||||
decoder_attentions += (outputs.attentions,)
|
||||
|
||||
if output_hidden_states:
|
||||
decoder_hidden_states += (
|
||||
(outputs.decoder_hidden_states,)
|
||||
if self.config.is_encoder_decoder
|
||||
else (outputs.hidden_states,)
|
||||
)
|
||||
decoder_hidden_states += (outputs.hidden_states,)
|
||||
|
||||
# sample
|
||||
probs = nn.functional.softmax(next_token_scores, dim=-1)
|
||||
|
@ -1573,7 +1237,7 @@ class QWenLMHeadModel(QWenPreTrainedModel):
|
|||
if streamer is not None:
|
||||
streamer.put(next_tokens.cpu())
|
||||
model_kwargs = self._update_model_kwargs_for_generation(
|
||||
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
|
||||
outputs, model_kwargs, is_encoder_decoder=False
|
||||
)
|
||||
|
||||
# if eos_token was found in one sentence, set sentence to finished
|
||||
|
@ -1592,7 +1256,7 @@ class QWenLMHeadModel(QWenPreTrainedModel):
|
|||
if stopping_criteria(input_ids, scores):
|
||||
this_peer_finished = True
|
||||
|
||||
if this_peer_finished and not synced_gpus:
|
||||
if this_peer_finished:
|
||||
break
|
||||
|
||||
if streamer is not None:
|
||||
|
@ -1600,44 +1264,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
|
|||
|
||||
return input_ids
|
||||
|
||||
# def backward(
|
||||
# self,
|
||||
# tokenizer,
|
||||
# query: str,
|
||||
# ):
|
||||
# inputs = tokenizer.build_chat_input(query, history=[], role="user")
|
||||
# inputs = inputs.to(next(self.parameters()).device)
|
||||
|
||||
# generation_config = copy.deepcopy(self.generation_config)
|
||||
# inputs_tensor = inputs["input_ids"]
|
||||
# input_ids = inputs_tensor.repeat_interleave(
|
||||
# generation_config.num_return_sequences, dim=0
|
||||
# )
|
||||
|
||||
# input_ids_in = input_ids
|
||||
# batch_size, seq_length = input_ids_in.shape
|
||||
# position_ids_in = (
|
||||
# torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
|
||||
# .unsqueeze(0)
|
||||
# .repeat(batch_size, 1)
|
||||
# )
|
||||
# model_inputs = {"input_ids": input_ids_in, "position_ids": position_ids_in}
|
||||
|
||||
# probs, next_tokens = self.transformer(
|
||||
# **model_inputs,
|
||||
# output_hidden_states=None,
|
||||
# tokenizer=tokenizer,
|
||||
# )
|
||||
|
||||
# next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
|
||||
# # probs_target = probs
|
||||
# # probs_target[0, next_tokens] = probs_target[0, next_tokens] * 1.1
|
||||
|
||||
# loss = probs[0, next_tokens]
|
||||
# loss.backward()
|
||||
|
||||
# return loss
|
||||
|
||||
|
||||
class RotaryEmbedding(torch.nn.Module):
|
||||
def __init__(self, dim, base=10000):
|
||||
|
@ -1703,14 +1329,6 @@ def apply_rotary_pos_emb(t, freqs):
|
|||
rot_dim = freqs[0].shape[-1]
|
||||
cos, sin = freqs
|
||||
t_float = t.float()
|
||||
if apply_rotary_emb_func is not None and t.is_cuda:
|
||||
# apply_rotary_emb in flash_attn requires cos/sin to be of
|
||||
# shape (seqlen, rotary_dim / 2) and apply rotary embedding
|
||||
# to the first rotary_dim of the input
|
||||
cos = cos.squeeze(0).squeeze(1)[:, : rot_dim // 2]
|
||||
sin = sin.squeeze(0).squeeze(1)[:, : rot_dim // 2]
|
||||
return apply_rotary_emb_func(t_float, cos, sin).type_as(t)
|
||||
else:
|
||||
t_rot, t_pass = t_float[..., :rot_dim], t_float[..., rot_dim:]
|
||||
t_rot = (t_rot * cos) + (_rotate_half(t_rot) * sin)
|
||||
return torch.cat((t_rot, t_pass), dim=-1).type_as(t)
|
||||
|
@ -1726,8 +1344,5 @@ class RMSNorm(torch.nn.Module):
|
|||
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
||||
|
||||
def forward(self, x):
|
||||
if rms_norm is not None and x.is_cuda:
|
||||
return rms_norm(x, self.weight, self.eps)
|
||||
else:
|
||||
output = self._norm(x.float()).type_as(x)
|
||||
return output * self.weight
|
||||
|
|
Loading…
Reference in New Issue