Refine model of qwen.

This commit is contained in:
Colin 2024-01-07 16:53:53 +08:00
parent 611396b656
commit 90cb0fe236
1 changed files with 51 additions and 436 deletions

View File

@ -36,12 +36,6 @@ except ImportError:
from torch import nn from torch import nn
SUPPORT_CUDA = torch.cuda.is_available() SUPPORT_CUDA = torch.cuda.is_available()
SUPPORT_BF16 = SUPPORT_CUDA and torch.cuda.is_bf16_supported()
SUPPORT_FP16 = SUPPORT_CUDA and torch.cuda.get_device_capability(0)[0] >= 7
SUPPORT_TORCH2 = (
hasattr(torch, "__version__") and int(torch.__version__.split(".")[0]) >= 2
)
from configuration_qwen import QWenConfig from configuration_qwen import QWenConfig
from qwen_generation_utils import ( from qwen_generation_utils import (
@ -69,37 +63,6 @@ Pass argument `stream` to model.chat() is buggy, deprecated, and marked for remo
向model.chat()传入参数stream的用法可能存在Bug该用法已被废弃将在未来被移除请使用model.chat_stream(...)代替model.chat(..., stream=True) 向model.chat()传入参数stream的用法可能存在Bug该用法已被废弃将在未来被移除请使用model.chat_stream(...)代替model.chat(..., stream=True)
""" """
apply_rotary_emb_func = None
rms_norm = None
def quantize_cache_v(fdata, bits, qmax, qmin):
# b, s, head, h-dim->b, head, s, h-dim
qtype = torch.uint8
device = fdata.device
shape = fdata.shape
fdata_cal = torch.flatten(fdata, 2)
fmax = torch.amax(fdata_cal, dim=-1, keepdim=True)
fmin = torch.amin(fdata_cal, dim=-1, keepdim=True)
# Compute params
if qmax.device != fmax.device:
qmax = qmax.to(device)
qmin = qmin.to(device)
scale = (fmax - fmin) / (qmax - qmin)
zero = qmin - fmin / scale
scale = scale.unsqueeze(-1).repeat(1, 1, shape[2], 1).contiguous()
zero = zero.unsqueeze(-1).repeat(1, 1, shape[2], 1).contiguous()
# Quantize
res_data = fdata / scale + zero
qdata = torch.clamp(res_data, qmin, qmax).to(qtype)
return qdata.contiguous(), scale, zero
def dequantize_cache_torch(qdata, scale, zero):
data = scale * (qdata - zero)
return data
class QWenAttention(nn.Module): class QWenAttention(nn.Module):
def __init__(self, config): def __init__(self, config):
@ -128,9 +91,7 @@ class QWenAttention(nn.Module):
config.hidden_size, self.projection_size, bias=not config.no_bias config.hidden_size, self.projection_size, bias=not config.no_bias
) )
self.is_fp32 = not (config.bf16 or config.fp16) self.is_fp32 = True
self.bf16 = config.bf16
self.use_dynamic_ntk = config.use_dynamic_ntk self.use_dynamic_ntk = config.use_dynamic_ntk
self.use_logn_attn = config.use_logn_attn self.use_logn_attn = config.use_logn_attn
@ -146,128 +107,13 @@ class QWenAttention(nn.Module):
self.softmax_in_fp32 = ( self.softmax_in_fp32 = (
config.softmax_in_fp32 if hasattr(config, "softmax_in_fp32") else False config.softmax_in_fp32 if hasattr(config, "softmax_in_fp32") else False
) )
self.use_cache_quantization = (
config.use_cache_quantization
if hasattr(config, "use_cache_quantization")
else False
)
self.use_cache_kernel = ( self.use_cache_kernel = (
config.use_cache_kernel if hasattr(config, "use_cache_kernel") else False config.use_cache_kernel if hasattr(config, "use_cache_kernel") else False
) )
cache_dtype = torch.float cache_dtype = torch.float
if self.bf16:
cache_dtype = torch.bfloat16
elif config.fp16:
cache_dtype = torch.float16
self.cache_qmax = torch.tensor(torch.iinfo(torch.uint8).max, dtype=cache_dtype) self.cache_qmax = torch.tensor(torch.iinfo(torch.uint8).max, dtype=cache_dtype)
self.cache_qmin = torch.tensor(torch.iinfo(torch.uint8).min, dtype=cache_dtype) self.cache_qmin = torch.tensor(torch.iinfo(torch.uint8).min, dtype=cache_dtype)
if config.use_cache_quantization and config.use_cache_kernel:
# pre check if the support files existing
module_root = pathlib.Path(__file__).parent
src_files = (
"cache_autogptq_cuda_256.cpp",
"cache_autogptq_cuda_kernel_256.cu",
)
if any(not (module_root / src).is_file() for src in src_files):
warnings.warn("KV cache kernel source files (.cpp and .cu) not found.")
self.cache_kernels = None
else:
try:
from .cpp_kernels import cache_autogptq_cuda_256
self.cache_kernels = cache_autogptq_cuda_256
except ImportError:
warnings.warn("Failed to import KV cache kernels.")
self.cache_kernels = None
def _attn(
self, query, key, value, causal_mask=None, attention_mask=None, head_mask=None
):
device = query.device
if self.use_cache_quantization:
qk, qk_scale, qk_zero = key
if self.use_cache_kernel and self.cache_kernels is not None:
shape = query.shape[:-1] + (qk.shape[-2],)
attn_weights = torch.zeros(shape, dtype=torch.float16, device=device)
self.cache_kernels.vecquant8matmul_batched_faster_old(
query.contiguous()
if query.dtype == torch.float16
else query.to(torch.float16).contiguous(),
qk.transpose(-1, -2).contiguous(),
attn_weights,
qk_scale.contiguous()
if qk_scale.dtype == torch.float16
else qk_scale.to(torch.float16).contiguous(),
qk_zero.contiguous()
if qk_zero.dtype == torch.float16
else qk_zero.to(torch.float16).contiguous(),
)
# attn_weights = attn_weights.to(query.dtype).contiguous()
else:
key = dequantize_cache_torch(qk, qk_scale, qk_zero)
attn_weights = torch.matmul(query, key.transpose(-1, -2))
else:
attn_weights = torch.matmul(query, key.transpose(-1, -2))
if self.scale_attn_weights:
if self.use_cache_quantization:
size_temp = value[0].size(-1)
else:
size_temp = value.size(-1)
attn_weights = attn_weights / (size_temp**0.5)
mask_value = torch.finfo(attn_weights.dtype).min
if causal_mask is not None:
attn_weights = torch.where(
causal_mask, attn_weights.to(attn_weights.dtype), mask_value
)
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
if self.softmax_in_fp32:
attn_weights = nn.functional.softmax(attn_weights.float(), dim=-1)
else:
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
attn_weights = attn_weights.type(query.dtype)
attn_weights = self.attn_dropout(attn_weights)
if head_mask is not None:
attn_weights = attn_weights * head_mask
if self.use_cache_quantization:
qv, qv_scale, qv_zero = value
if self.use_cache_kernel and self.cache_kernels is not None:
shape = attn_weights.shape[:-1] + (query.shape[-1],)
attn_output = torch.zeros(shape, dtype=torch.float16, device=device)
self.cache_kernels.vecquant8matmul_batched_column_compression_faster_old(
attn_weights.contiguous()
if attn_weights.dtype == torch.float16
else attn_weights.to(torch.float16).contiguous(),
qv.contiguous(), # dtype: int32
attn_output,
qv_scale.contiguous()
if qv_scale.dtype == torch.float16
else qv_scale.to(torch.float16).contiguous(),
qv_zero.contiguous()
if qv_zero.dtype == torch.float16
else qv_zero.to(torch.float16).contiguous(),
)
if attn_output.dtype != query.dtype:
attn_output = attn_output.to(query.dtype)
attn_weights = attn_weights.to(query.dtype)
else:
value = dequantize_cache_torch(qv, qv_scale, qv_zero)
attn_output = torch.matmul(attn_weights, value)
else:
attn_output = torch.matmul(attn_weights, value)
attn_output = attn_output.transpose(1, 2)
return attn_output, attn_weights
def _split_heads(self, tensor, num_heads, attn_head_size): def _split_heads(self, tensor, num_heads, attn_head_size):
new_shape = tensor.size()[:-1] + (num_heads, attn_head_size) new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
tensor = tensor.view(new_shape) tensor = tensor.view(new_shape)
@ -323,59 +169,25 @@ class QWenAttention(nn.Module):
query = torch.cat(query_list, dim=0) query = torch.cat(query_list, dim=0)
key = torch.cat(key_list, dim=0) key = torch.cat(key_list, dim=0)
if self.use_cache_quantization:
key = quantize_cache_v(
key.permute(0, 2, 1, 3),
bits=8,
qmin=self.cache_qmin,
qmax=self.cache_qmax,
)
value = quantize_cache_v(
value.permute(0, 2, 1, 3),
bits=8,
qmin=self.cache_qmin,
qmax=self.cache_qmax,
)
if layer_past is not None: if layer_past is not None:
past_key, past_value = layer_past[0], layer_past[1] past_key, past_value = layer_past[0], layer_past[1]
if self.use_cache_quantization:
# use_cache_quantization: key = torch.cat((past_key, key), dim=1)
# present=((q_key,key_scale,key_zero_point), value = torch.cat((past_value, value), dim=1)
# (q_value,value_scale,value_zero_point))
key = (
torch.cat((past_key[0], key[0]), dim=2),
torch.cat((past_key[1], key[1]), dim=2),
torch.cat((past_key[2], key[2]), dim=2),
)
value = (
torch.cat((past_value[0], value[0]), dim=2),
torch.cat((past_value[1], value[1]), dim=2),
torch.cat((past_value[2], value[2]), dim=2),
)
else:
# not use_cache_quantization:
# present=(key,value)
key = torch.cat((past_key, key), dim=1)
value = torch.cat((past_value, value), dim=1)
if use_cache: if use_cache:
present = (key, value) present = (key, value)
else: else:
present = None present = None
key_size = key[0].size(2) if self.use_cache_quantization else key.size(1) key_size = key.size(1)
if key_size > self.seq_length and self.use_logn_attn and not self.training: if key_size > self.seq_length and self.use_logn_attn and not self.training:
if self.use_cache_quantization: seq_start = key.size(1) - query.size(1)
seq_start = key[0].size(2) - query.size(1) seq_end = key.size(1)
seq_end = key[0].size(2)
else:
seq_start = key.size(1) - query.size(1)
seq_end = key.size(1)
logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :].type_as(query) logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :].type_as(query)
query = query * logn_tensor.expand_as(query) query = query * logn_tensor.expand_as(query)
key_size = key[0].size(2) if self.use_cache_quantization else key.size(1) key_size = key.size(1)
if query.size(1) == key_size: if query.size(1) == key_size:
causal_mask = torch.tril( causal_mask = torch.tril(
torch.ones((key_size, key_size), dtype=torch.bool, device=query.device) torch.ones((key_size, key_size), dtype=torch.bool, device=query.device)
@ -383,39 +195,31 @@ class QWenAttention(nn.Module):
else: else:
causal_mask = None causal_mask = None
query = query.permute(0, 2, 1, 3) query = query.permute(0, 2, 1, 3)
if not self.use_cache_quantization: key = key.permute(0, 2, 1, 3)
key = key.permute(0, 2, 1, 3) value = value.permute(0, 2, 1, 3)
value = value.permute(0, 2, 1, 3)
if not self.use_cache_quantization and SUPPORT_TORCH2: if attention_mask is not None:
if attention_mask is not None: attention_mask = attention_mask.expand(-1, -1, causal_mask.size(2), -1)
attention_mask = attention_mask.expand(-1, -1, causal_mask.size(2), -1) if causal_mask is not None:
if causal_mask is not None: attention_mask = attention_mask.masked_fill(
attention_mask = attention_mask.masked_fill( ~causal_mask, torch.finfo(query.dtype).min
~causal_mask, torch.finfo(query.dtype).min )
)
else:
attention_mask = causal_mask
attn_output = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask
).transpose(1, 2)
attn_weight = None
else: else:
attn_output, attn_weight = self._attn( attention_mask = causal_mask
query, key, value, causal_mask, attention_mask, head_mask attn_output = F.scaled_dot_product_attention(
) query, key, value, attn_mask=attention_mask
).transpose(1, 2)
attn_weight = None
context_layer = self._merge_heads(attn_output, self.num_heads, self.head_dim) context_layer = self._merge_heads(attn_output, self.num_heads, self.head_dim)
attn_output = self.c_proj(context_layer) attn_output = self.c_proj(context_layer)
outputs = (attn_output, present) outputs = (attn_output, present)
if output_attentions: if output_attentions:
if not self.use_cache_quantization and SUPPORT_TORCH2: raise ValueError(
raise ValueError( "Cannot output attentions while using scaled_dot_product_attention"
"Cannot output attentions while using scaled_dot_product_attention" )
)
else:
outputs += (attn_weight,)
return outputs return outputs
@ -444,7 +248,6 @@ class QWenBlock(nn.Module):
def __init__(self, config): def __init__(self, config):
super().__init__() super().__init__()
hidden_size = config.hidden_size hidden_size = config.hidden_size
self.bf16 = config.bf16
self.ln_1 = RMSNorm( self.ln_1 = RMSNorm(
hidden_size, hidden_size,
@ -549,11 +352,6 @@ class QWenModel(QWenPreTrainedModel):
self.vocab_size = config.vocab_size self.vocab_size = config.vocab_size
self.num_hidden_layers = config.num_hidden_layers self.num_hidden_layers = config.num_hidden_layers
self.embed_dim = config.hidden_size self.embed_dim = config.hidden_size
self.use_cache_quantization = (
self.config.use_cache_quantization
if hasattr(self.config, "use_cache_quantization")
else False
)
self.gradient_checkpointing = False self.gradient_checkpointing = False
self.use_dynamic_ntk = config.use_dynamic_ntk self.use_dynamic_ntk = config.use_dynamic_ntk
@ -571,7 +369,7 @@ class QWenModel(QWenPreTrainedModel):
dim = self.rotary_ndims if self.rotary_ndims is not None else config.kv_channels dim = self.rotary_ndims if self.rotary_ndims is not None else config.kv_channels
self.rotary_emb = RotaryEmbedding(dim, base=config.rotary_emb_base) self.rotary_emb = RotaryEmbedding(dim, base=config.rotary_emb_base)
self.is_fp32 = not (config.bf16 or config.fp16) self.is_fp32 = True
self.h = nn.ModuleList( self.h = nn.ModuleList(
[QWenBlock(config) for i in range(config.num_hidden_layers)] [QWenBlock(config) for i in range(config.num_hidden_layers)]
@ -651,10 +449,7 @@ class QWenModel(QWenPreTrainedModel):
past_length = 0 past_length = 0
past_key_values = tuple([None] * len(self.h)) past_key_values = tuple([None] * len(self.h))
else: else:
if self.use_cache_quantization: past_length = past_key_values[0][0].size(-2)
past_length = past_key_values[0][0][0].size(2)
else:
past_length = past_key_values[0][0].size(-2)
if position_ids is None: if position_ids is None:
position_ids = torch.arange( position_ids = torch.arange(
past_length, past_length,
@ -682,10 +477,7 @@ class QWenModel(QWenPreTrainedModel):
kv_seq_len = hidden_states.size()[1] kv_seq_len = hidden_states.size()[1]
if past_key_values[0] is not None: if past_key_values[0] is not None:
# past key values[0][0] shape: bs * seq_len * head_num * dim # past key values[0][0] shape: bs * seq_len * head_num * dim
if self.use_cache_quantization: kv_seq_len += past_key_values[0][0].shape[1]
kv_seq_len += past_key_values[0][0][0].shape[2]
else:
kv_seq_len += past_key_values[0][0].shape[1]
if self.training or not self.use_dynamic_ntk: if self.training or not self.use_dynamic_ntk:
ntk_alpha_list = [1.0] ntk_alpha_list = [1.0]
@ -796,55 +588,9 @@ class QWenLMHeadModel(QWenPreTrainedModel):
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)
assert (
config.bf16 + config.fp16 + config.fp32 <= 1
), 'Only one of "bf16", "fp16", "fp32" can be true'
autoset_precision = config.bf16 + config.fp16 + config.fp32 == 0
if autoset_precision:
if SUPPORT_BF16:
logger.warn(
"The model is automatically converting to bf16 for faster inference. "
'If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to "AutoModelForCausalLM.from_pretrained".'
)
config.bf16 = True
elif SUPPORT_FP16:
logger.warn(
"The model is automatically converting to fp16 for faster inference. "
'If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to "AutoModelForCausalLM.from_pretrained".'
)
config.fp16 = True
else:
config.fp32 = True
if config.bf16 and SUPPORT_CUDA and not SUPPORT_BF16:
logger.warn(
'Your device does NOT seem to support bf16, you can switch to fp16 or fp32 by by passing fp16/fp32=True in "AutoModelForCausalLM.from_pretrained".'
)
if config.fp16 and SUPPORT_CUDA and not SUPPORT_FP16:
logger.warn(
"Your device does NOT support faster inference with fp16, please switch to fp32 which is likely to be faster"
)
if config.fp32:
if SUPPORT_BF16:
logger.warn(
'Your device support faster inference by passing bf16=True in "AutoModelForCausalLM.from_pretrained".'
)
elif SUPPORT_FP16:
logger.warn(
'Your device support faster inference by passing fp16=True in "AutoModelForCausalLM.from_pretrained".'
)
self.transformer = QWenModel(config) self.transformer = QWenModel(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
if config.bf16:
self.transformer.bfloat16()
self.lm_head.bfloat16()
if config.fp16:
self.transformer.half()
self.lm_head.half()
self.post_init() self.post_init()
def get_output_embeddings(self): def get_output_embeddings(self):
@ -928,13 +674,13 @@ class QWenLMHeadModel(QWenPreTrainedModel):
shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1) shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
) )
# shift_labels = torch.ones([1,19]).to(lm_logits.device).to(torch.int64) shift_labels = torch.ones([1,19]).to(lm_logits.device).to(torch.int64)
# shift_logits = lm_logits[..., :-1, :].contiguous() shift_logits = lm_logits[..., :-1, :].contiguous()
# loss_fct = CrossEntropyLoss() loss_fct = CrossEntropyLoss()
# loss = loss_fct( loss = loss_fct(
# shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1) shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
# ) )
# loss.backward() loss.backward()
if not return_dict: if not return_dict:
output = (lm_logits,) + transformer_outputs[1:] output = (lm_logits,) + transformer_outputs[1:]
@ -948,18 +694,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
attentions=transformer_outputs.attentions, attentions=transformer_outputs.attentions,
) )
@staticmethod
def _reorder_cache(
past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
) -> Tuple[Tuple[torch.Tensor]]:
return tuple(
tuple(
past_state.index_select(0, beam_idx.to(past_state.device))
for past_state in layer_past
)
for layer_past in past_key_values
)
def chat( def chat(
self, self,
tokenizer: PreTrainedTokenizer, tokenizer: PreTrainedTokenizer,
@ -1171,9 +905,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
negative_prompt_attention_mask: Optional[torch.Tensor] = None, negative_prompt_attention_mask: Optional[torch.Tensor] = None,
**kwargs, **kwargs,
) -> Union[GenerateOutput, torch.LongTensor]: ) -> Union[GenerateOutput, torch.LongTensor]:
if synced_gpus is None:
synced_gpus = False
# 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
self._validate_model_class() self._validate_model_class()
@ -1271,44 +1002,13 @@ class QWenLMHeadModel(QWenPreTrainedModel):
generation_config.eos_token_id, generation_config.eos_token_id,
) )
# decoder-only models should use left-padding for generation
if not self.config.is_encoder_decoder:
# If `input_ids` was given, check if the last id in any sequence is `pad_token_id`
# Note: If using, `inputs_embeds` this check does not work, because we want to be more hands-off.
if (
generation_config.pad_token_id is not None
and len(inputs_tensor.shape) == 2
and torch.sum(inputs_tensor[:, -1] == generation_config.pad_token_id)
> 0
):
logger.warning(
"A decoder-only architecture is being used, but right-padding was detected! For correct "
"generation results, please set `padding_side='left'` when initializing the tokenizer."
)
if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs:
# if model is encoder decoder encoder_outputs are created
# and added to `model_kwargs`
model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
inputs_tensor, model_kwargs, model_input_name
)
# 5. Prepare `input_ids` which will be used for auto-regressive generation # 5. Prepare `input_ids` which will be used for auto-regressive generation
if self.config.is_encoder_decoder:
input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation( input_ids = (
batch_size=batch_size, inputs_tensor
model_input_name=model_input_name, if model_input_name == "input_ids"
model_kwargs=model_kwargs, else model_kwargs.pop("input_ids")
decoder_start_token_id=generation_config.decoder_start_token_id, )
bos_token_id=generation_config.bos_token_id,
device=inputs_tensor.device,
)
else:
input_ids = (
inputs_tensor
if model_input_name == "input_ids"
else model_kwargs.pop("input_ids")
)
if streamer is not None: if streamer is not None:
streamer.put(input_ids.cpu()) streamer.put(input_ids.cpu())
@ -1378,7 +1078,7 @@ class QWenLMHeadModel(QWenPreTrainedModel):
input_ids, model_kwargs = self._expand_inputs_for_generation( input_ids, model_kwargs = self._expand_inputs_for_generation(
input_ids=input_ids, input_ids=input_ids,
expand_size=generation_config.num_return_sequences, expand_size=generation_config.num_return_sequences,
is_encoder_decoder=self.config.is_encoder_decoder, is_encoder_decoder=False,
**model_kwargs, **model_kwargs,
) )
@ -1483,19 +1183,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
() if (return_dict_in_generate and output_hidden_states) else None () if (return_dict_in_generate and output_hidden_states) else None
) )
# if model is an encoder-decoder, retrieve encoder attention weights and hidden states
if return_dict_in_generate and self.config.is_encoder_decoder:
encoder_attentions = (
model_kwargs["encoder_outputs"].get("attentions")
if output_attentions
else None
)
encoder_hidden_states = (
model_kwargs["encoder_outputs"].get("hidden_states")
if output_hidden_states
else None
)
# keep track of which sequences are already finished # keep track of which sequences are already finished
unfinished_sequences = torch.ones( unfinished_sequences = torch.ones(
input_ids.shape[0], dtype=torch.long, device=input_ids.device input_ids.shape[0], dtype=torch.long, device=input_ids.device
@ -1504,16 +1191,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
this_peer_finished = False # used by synced_gpus only this_peer_finished = False # used by synced_gpus only
# auto-regressive generation # auto-regressive generation
while True: while True:
# if synced_gpus:
# # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
# # The following logic allows an early break if all peers finished generating their sequence
# this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
# # send 0.0 if we finished, 1.0 otherwise
# dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
# # did all peers finish? the reduced sum will be 0.0 then
# if this_peer_finished_flag.item() == 0.0:
# break
# prepare model inputs # prepare model inputs
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
@ -1525,9 +1202,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
) )
if synced_gpus and this_peer_finished:
continue # don't waste resources running the code we don't need
next_token_logits = outputs.logits[:, -1, :] next_token_logits = outputs.logits[:, -1, :]
# pre-process distribution # pre-process distribution
@ -1539,20 +1213,10 @@ class QWenLMHeadModel(QWenPreTrainedModel):
if output_scores: if output_scores:
scores += (next_token_scores,) scores += (next_token_scores,)
if output_attentions: if output_attentions:
decoder_attentions += ( decoder_attentions += (outputs.attentions,)
(outputs.decoder_attentions,)
if self.config.is_encoder_decoder
else (outputs.attentions,)
)
if self.config.is_encoder_decoder:
cross_attentions += (outputs.cross_attentions,)
if output_hidden_states: if output_hidden_states:
decoder_hidden_states += ( decoder_hidden_states += (outputs.hidden_states,)
(outputs.decoder_hidden_states,)
if self.config.is_encoder_decoder
else (outputs.hidden_states,)
)
# sample # sample
probs = nn.functional.softmax(next_token_scores, dim=-1) probs = nn.functional.softmax(next_token_scores, dim=-1)
@ -1573,7 +1237,7 @@ class QWenLMHeadModel(QWenPreTrainedModel):
if streamer is not None: if streamer is not None:
streamer.put(next_tokens.cpu()) streamer.put(next_tokens.cpu())
model_kwargs = self._update_model_kwargs_for_generation( model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder outputs, model_kwargs, is_encoder_decoder=False
) )
# if eos_token was found in one sentence, set sentence to finished # if eos_token was found in one sentence, set sentence to finished
@ -1592,7 +1256,7 @@ class QWenLMHeadModel(QWenPreTrainedModel):
if stopping_criteria(input_ids, scores): if stopping_criteria(input_ids, scores):
this_peer_finished = True this_peer_finished = True
if this_peer_finished and not synced_gpus: if this_peer_finished:
break break
if streamer is not None: if streamer is not None:
@ -1600,44 +1264,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
return input_ids return input_ids
# def backward(
# self,
# tokenizer,
# query: str,
# ):
# inputs = tokenizer.build_chat_input(query, history=[], role="user")
# inputs = inputs.to(next(self.parameters()).device)
# generation_config = copy.deepcopy(self.generation_config)
# inputs_tensor = inputs["input_ids"]
# input_ids = inputs_tensor.repeat_interleave(
# generation_config.num_return_sequences, dim=0
# )
# input_ids_in = input_ids
# batch_size, seq_length = input_ids_in.shape
# position_ids_in = (
# torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
# .unsqueeze(0)
# .repeat(batch_size, 1)
# )
# model_inputs = {"input_ids": input_ids_in, "position_ids": position_ids_in}
# probs, next_tokens = self.transformer(
# **model_inputs,
# output_hidden_states=None,
# tokenizer=tokenizer,
# )
# next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
# # probs_target = probs
# # probs_target[0, next_tokens] = probs_target[0, next_tokens] * 1.1
# loss = probs[0, next_tokens]
# loss.backward()
# return loss
class RotaryEmbedding(torch.nn.Module): class RotaryEmbedding(torch.nn.Module):
def __init__(self, dim, base=10000): def __init__(self, dim, base=10000):
@ -1703,17 +1329,9 @@ def apply_rotary_pos_emb(t, freqs):
rot_dim = freqs[0].shape[-1] rot_dim = freqs[0].shape[-1]
cos, sin = freqs cos, sin = freqs
t_float = t.float() t_float = t.float()
if apply_rotary_emb_func is not None and t.is_cuda: t_rot, t_pass = t_float[..., :rot_dim], t_float[..., rot_dim:]
# apply_rotary_emb in flash_attn requires cos/sin to be of t_rot = (t_rot * cos) + (_rotate_half(t_rot) * sin)
# shape (seqlen, rotary_dim / 2) and apply rotary embedding return torch.cat((t_rot, t_pass), dim=-1).type_as(t)
# to the first rotary_dim of the input
cos = cos.squeeze(0).squeeze(1)[:, : rot_dim // 2]
sin = sin.squeeze(0).squeeze(1)[:, : rot_dim // 2]
return apply_rotary_emb_func(t_float, cos, sin).type_as(t)
else:
t_rot, t_pass = t_float[..., :rot_dim], t_float[..., rot_dim:]
t_rot = (t_rot * cos) + (_rotate_half(t_rot) * sin)
return torch.cat((t_rot, t_pass), dim=-1).type_as(t)
class RMSNorm(torch.nn.Module): class RMSNorm(torch.nn.Module):
@ -1726,8 +1344,5 @@ class RMSNorm(torch.nn.Module):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x): def forward(self, x):
if rms_norm is not None and x.is_cuda: output = self._norm(x.float()).type_as(x)
return rms_norm(x, self.weight, self.eps) return output * self.weight
else:
output = self._norm(x.float()).type_as(x)
return output * self.weight