Reformat qwen.

Colin 2024-01-07 21:54:37 +08:00
parent aa2d3b96c4
commit 4c0991a409
1 changed file with 59 additions and 204 deletions


@@ -50,6 +50,7 @@ logger = logging.get_logger(__name__)
_SENTINEL = object()
class QWenAttention(nn.Module):
def __init__(self, config):
super().__init__()
@@ -67,33 +68,22 @@ class QWenAttention(nn.Module):
self.projection_size = config.kv_channels * config.num_attention_heads
assert self.projection_size % config.num_attention_heads == 0
self.hidden_size_per_attention_head = (
self.projection_size // config.num_attention_heads
)
self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads
self.c_attn = nn.Linear(config.hidden_size, 3 * self.projection_size)
self.c_proj = nn.Linear(
config.hidden_size, self.projection_size, bias=not config.no_bias
)
self.c_proj = nn.Linear(config.hidden_size, self.projection_size, bias=not config.no_bias)
self.use_dynamic_ntk = config.use_dynamic_ntk
self.use_logn_attn = config.use_logn_attn
logn_list = [
math.log(i, self.seq_length) if i > self.seq_length else 1
for i in range(1, 32768)
]
logn_list = [math.log(i, self.seq_length) if i > self.seq_length else 1 for i in range(1, 32768)]
logn_tensor = torch.tensor(logn_list)[None, :, None, None]
self.register_buffer("logn_tensor", logn_tensor, persistent=False)
self.attn_dropout = nn.Dropout(config.attn_dropout_prob)
self.softmax_in_fp32 = (
config.softmax_in_fp32 if hasattr(config, "softmax_in_fp32") else False
)
self.use_cache_kernel = (
config.use_cache_kernel if hasattr(config, "use_cache_kernel") else False
)
self.softmax_in_fp32 = config.softmax_in_fp32 if hasattr(config, "softmax_in_fp32") else False
self.use_cache_kernel = config.use_cache_kernel if hasattr(config, "use_cache_kernel") else False
cache_dtype = torch.float
self.cache_qmax = torch.tensor(torch.iinfo(torch.uint8).max, dtype=cache_dtype)
self.cache_qmin = torch.tensor(torch.iinfo(torch.uint8).min, dtype=cache_dtype)
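Note on the logn scaling buffer kept above: entries are 1 inside the trained window and log base seq_length of the position beyond it. A minimal standalone sketch, assuming a trained context length of 2048 (the value is illustrative, not taken from this diff):

import math
import torch

seq_length = 2048  # assumed trained context length; the real value comes from config
logn_list = [math.log(i, seq_length) if i > seq_length else 1 for i in range(1, 32768)]
logn_tensor = torch.tensor(logn_list)[None, :, None, None]  # shape (1, 32767, 1, 1)
# at inference time, queries at positions past seq_length are multiplied by these
# factors so attention entropy stays roughly stable on longer-than-trained inputs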
@@ -114,10 +104,6 @@ class QWenAttention(nn.Module):
rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None,
layer_past: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
):
mixed_x_layer = self.c_attn(hidden_states)
@@ -146,9 +132,7 @@ class QWenAttention(nn.Module):
rotary_pos_emb = (rotary_pos_emb,) * 2
q_pos_emb, k_pos_emb = rotary_pos_emb
# Slice the pos emb for current inference
query_list += [
apply_rotary_pos_emb(query[i : i + 1, :, :], q_pos_emb)
]
query_list += [apply_rotary_pos_emb(query[i : i + 1, :, :], q_pos_emb)]
key_list += [apply_rotary_pos_emb(key[i : i + 1, :, :], k_pos_emb)]
query = torch.cat(query_list, dim=0)
key = torch.cat(key_list, dim=0)
@@ -173,9 +157,9 @@ class QWenAttention(nn.Module):
key_size = key.size(1)
if query.size(1) == key_size:
causal_mask = torch.tril(
torch.ones((key_size, key_size), dtype=torch.bool, device=query.device)
).view(1, 1, key_size, key_size)
causal_mask = torch.tril(torch.ones((key_size, key_size), dtype=torch.bool, device=query.device)).view(
1, 1, key_size, key_size
)
else:
causal_mask = None
query = query.permute(0, 2, 1, 3)
@@ -185,24 +169,16 @@ class QWenAttention(nn.Module):
if attention_mask is not None:
attention_mask = attention_mask.expand(-1, -1, causal_mask.size(2), -1)
if causal_mask is not None:
attention_mask = attention_mask.masked_fill(
~causal_mask, torch.finfo(query.dtype).min
)
attention_mask = attention_mask.masked_fill(~causal_mask, torch.finfo(query.dtype).min)
else:
attention_mask = causal_mask
attn_output = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask
).transpose(1, 2)
attn_output = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask).transpose(1, 2)
context_layer = self._merge_heads(attn_output, self.num_heads, self.head_dim)
attn_output = self.c_proj(context_layer)
outputs = (attn_output, present)
if output_attentions:
raise ValueError(
"Cannot output attentions while using scaled_dot_product_attention"
)
return outputs
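A self-contained sketch of the masking path in this hunk: a lower-triangular causal mask, an additive padding mask folded in with finfo(dtype).min, and F.scaled_dot_product_attention. Shapes are toy values chosen for illustration:

import torch
import torch.nn.functional as F

bsz, n_heads, q_len, head_dim = 1, 2, 8, 16
query = torch.randn(bsz, n_heads, q_len, head_dim)
key = torch.randn(bsz, n_heads, q_len, head_dim)
value = torch.randn(bsz, n_heads, q_len, head_dim)

# lower-triangular causal mask, broadcastable over batch and heads
causal_mask = torch.tril(torch.ones((q_len, q_len), dtype=torch.bool)).view(1, 1, q_len, q_len)

# additive padding mask (0 = attend), expanded and merged with the causal mask
attention_mask = torch.zeros(bsz, 1, 1, q_len)
attention_mask = attention_mask.expand(-1, -1, causal_mask.size(2), -1)
attention_mask = attention_mask.masked_fill(~causal_mask, torch.finfo(query.dtype).min)

attn_output = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask).transpose(1, 2)
# attn_output: (bsz, q_len, n_heads, head_dim), ready to be merged back to hidden_size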
@@ -210,12 +186,8 @@ class QWenAttention(nn.Module):
class QWenMLP(nn.Module):
def __init__(self, config):
super().__init__()
self.w1 = nn.Linear(
config.hidden_size, config.intermediate_size // 2, bias=not config.no_bias
)
self.w2 = nn.Linear(
config.hidden_size, config.intermediate_size // 2, bias=not config.no_bias
)
self.w1 = nn.Linear(config.hidden_size, config.intermediate_size // 2, bias=not config.no_bias)
self.w2 = nn.Linear(config.hidden_size, config.intermediate_size // 2, bias=not config.no_bias)
ff_dim_in = config.intermediate_size // 2
self.c_proj = nn.Linear(ff_dim_in, config.hidden_size, bias=not config.no_bias)
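The MLP forward is not part of this hunk; the gating below is the usual SwiGLU-style reading of these three projections (w1's branch gated by silu(w2), then projected back by c_proj) and should be treated as an assumption rather than this file's code:

import torch.nn.functional as F

def qwen_mlp_forward(mlp, hidden_states):
    # assumed gating: silu(w2(x)) gates w1(x); c_proj maps intermediate_size // 2 back to hidden_size
    a1 = mlp.w1(hidden_states)
    a2 = mlp.w2(hidden_states)
    return mlp.c_proj(a1 * F.silu(a2))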
@@ -263,10 +235,9 @@ class QWenBlock(nn.Module):
rotary_pos_emb_list,
layer_past=layer_past,
attention_mask=attention_mask,
head_mask=head_mask,
use_cache=use_cache,
output_attentions=output_attentions,
)
attn_output = attn_outputs[0]
outputs = attn_outputs[1:]
@@ -299,8 +270,8 @@ class QWenPreTrainedModel(PreTrainedModel):
def __init__(self, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)
class QWenModel(QWenPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.vocab_size = config.vocab_size
@@ -322,9 +293,7 @@ class QWenModel(QWenPreTrainedModel):
dim = self.rotary_ndims if self.rotary_ndims is not None else config.kv_channels
self.rotary_emb = RotaryEmbedding(dim, base=config.rotary_emb_base)
self.h = nn.ModuleList(
[QWenBlock(config) for i in range(config.num_hidden_layers)]
)
self.h = nn.ModuleList([QWenBlock(config) for i in range(config.num_hidden_layers)])
self.ln_f = RMSNorm(
self.embed_dim,
eps=config.layer_norm_epsilon,
@@ -348,19 +317,13 @@ class QWenModel(QWenPreTrainedModel):
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None
output_attentions: Optional[bool] = None,
):
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
use_cache = use_cache if use_cache is not None else self.config.use_cache
if input_ids is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both input_ids and inputs_embeds at the same time"
)
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
@@ -399,12 +362,7 @@ class QWenModel(QWenPreTrainedModel):
else:
ntk_alpha_list = []
if attention_mask is not None and kv_seq_len > self.seq_length:
true_seq_lens = (
attention_mask.squeeze(1)
.squeeze(1)
.eq(0)
.sum(dim=-1, dtype=torch.int32)
)
true_seq_lens = attention_mask.squeeze(1).squeeze(1).eq(0).sum(dim=-1, dtype=torch.int32)
for i in range(hidden_states.size()[0]):
true_seq_len = true_seq_lens[i].item()
ntk_alpha = self.get_ntk_alpha(true_seq_len)
@@ -413,10 +371,7 @@ class QWenModel(QWenPreTrainedModel):
ntk_alpha = self.get_ntk_alpha(kv_seq_len)
ntk_alpha_list.append(ntk_alpha)
self.rotary_emb._ntk_alpha_cached_list = ntk_alpha_list
rotary_pos_emb_list = [
self.rotary_emb(kv_seq_len, ntk_alpha=ntk_alpha)
for ntk_alpha in ntk_alpha_list
]
rotary_pos_emb_list = [self.rotary_emb(kv_seq_len, ntk_alpha=ntk_alpha) for ntk_alpha in ntk_alpha_list]
hidden_states = self.drop(hidden_states)
output_shape = input_shape + (hidden_states.size(-1),)
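Sketch of the per-sample NTK selection above: the additive mask (0 = attend, finfo.min = masked) gives each sample's true length, and get_ntk_alpha maps it to a scaling factor. The rule inside get_ntk_alpha is not in this hunk; the power-of-two rule below is the commonly cited Qwen behaviour and is an assumption here:

import math
import torch

def get_ntk_alpha(true_seq_len, seq_length=2048):
    # assumed rule: grow alpha each time the input overflows another power of two of seq_length
    context_value = math.log(true_seq_len / seq_length, 2) + 1
    ntk_alpha = 2 ** math.ceil(context_value) - 1
    return max(ntk_alpha, 1)

attention_mask = torch.zeros(2, 1, 1, 4096)                      # additive mask, 0 = attendable
attention_mask[1, :, :, :2048] = torch.finfo(torch.float32).min  # sample 1 is half padding
true_seq_lens = attention_mask.squeeze(1).squeeze(1).eq(0).sum(dim=-1, dtype=torch.int32)
ntk_alpha_list = [get_ntk_alpha(int(n)) for n in true_seq_lens]  # [3, 1] for this toy mask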
@@ -425,7 +380,6 @@ class QWenModel(QWenPreTrainedModel):
all_self_attentions = () if output_attentions else None
all_hidden_states = None
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
outputs = block(
hidden_states,
layer_past=layer_past,
@@ -443,9 +397,7 @@ class QWenModel(QWenPreTrainedModel):
presents = presents + (outputs[1],)
if output_attentions:
all_self_attentions = all_self_attentions + (
outputs[2 if use_cache else 1],
)
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
hidden_states = self.ln_f(hidden_states)
hidden_states = hidden_states.view(output_shape)
@@ -458,7 +410,6 @@ class QWenModel(QWenPreTrainedModel):
class QWenLMHeadModel(QWenPreTrainedModel):
def __init__(self, config):
super().__init__(config)
@@ -466,9 +417,7 @@ class QWenLMHeadModel(QWenPreTrainedModel):
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.post_init()
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs
):
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
if past_key_values:
input_ids = input_ids[:, -1].unsqueeze(-1)
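The slice above is the standard KV-cache pattern: once past_key_values covers everything already processed, only the newest token needs to be fed forward. Toy shapes:

import torch

input_ids = torch.tensor([[11, 23, 42, 7]])    # (batch=1, seq=4) prompt
past_key_values = None

first_step = input_ids                          # no cache yet: feed the whole prompt, shape (1, 4)

past_key_values = ("k/v from step 1",)          # stand-in: any non-empty value means a cache exists
next_step = input_ids[:, -1].unsqueeze(-1)      # with a cache: feed only the last token, shape (1, 1)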
@@ -502,10 +451,8 @@ class QWenLMHeadModel(QWenPreTrainedModel):
encoder_attention_mask: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None
output_attentions: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
transformer_outputs = self.transformer(
input_ids,
past_key_values=past_key_values,
@@ -515,7 +462,7 @@ class QWenLMHeadModel(QWenPreTrainedModel):
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=use_cache,
output_attentions=output_attentions
output_attentions=output_attentions,
)
hidden_states = transformer_outputs[0]
@@ -527,9 +474,7 @@ class QWenLMHeadModel(QWenPreTrainedModel):
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss_fct = CrossEntropyLoss()
loss = loss_fct(
shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
)
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
# shift_labels = torch.ones([1,19]).to(lm_logits.device).to(torch.int64)
# shift_logits = lm_logits[..., :-1, :].contiguous()
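The shift above lines position t's logits up with the token at position t+1. A toy, self-contained version of the same loss:

import torch
from torch.nn import CrossEntropyLoss

vocab_size = 10
lm_logits = torch.randn(1, 5, vocab_size)            # (batch, seq, vocab)
labels = torch.randint(0, vocab_size, (1, 5))

shift_logits = lm_logits[..., :-1, :].contiguous()   # predictions for positions 0..3
shift_labels = labels[..., 1:].contiguous()          # targets are the next tokens, positions 1..4
loss = CrossEntropyLoss()(shift_logits.view(-1, vocab_size), shift_labels.view(-1))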
@@ -558,11 +503,7 @@ class QWenLMHeadModel(QWenPreTrainedModel):
generation_config: Optional[GenerationConfig] = None,
**kwargs,
) -> Tuple[str, HistoryType]:
generation_config = (
generation_config
if generation_config is not None
else self.generation_config
)
generation_config = generation_config if generation_config is not None else self.generation_config
assert stream is _SENTINEL
assert generation_config.chat_format == "chatml"
@@ -587,9 +528,7 @@ class QWenLMHeadModel(QWenPreTrainedModel):
chat_format=generation_config.chat_format,
)
stop_words_ids.extend(
get_stop_words_ids(generation_config.chat_format, tokenizer)
)
stop_words_ids.extend(get_stop_words_ids(generation_config.chat_format, tokenizer))
input_ids = torch.tensor([context_tokens]).to(self.device)
outputs = self.generate(
input_ids,
@@ -622,19 +561,13 @@ class QWenLMHeadModel(QWenPreTrainedModel):
generation_config: Optional[GenerationConfig] = None,
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
prefix_allowed_tokens_fn: Optional[
Callable[[int, torch.Tensor], List[int]]
] = None,
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
synced_gpus: Optional[bool] = None,
assistant_model: Optional["PreTrainedModel"] = None,
streamer: Optional["BaseStreamer"] = None,
**kwargs,
) -> Union[GenerateOutput, torch.LongTensor]:
generation_config = (
generation_config
if generation_config is not None
else self.generation_config
)
generation_config = generation_config if generation_config is not None else self.generation_config
# Process stop_words_ids.
stop_words_ids = kwargs.pop("stop_words_ids", None)
@@ -671,9 +604,7 @@ class QWenLMHeadModel(QWenPreTrainedModel):
generation_config: Optional[GenerationConfig] = None,
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
prefix_allowed_tokens_fn: Optional[
Callable[[int, torch.Tensor], List[int]]
] = None,
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
synced_gpus: Optional[bool] = None,
assistant_model: Optional["PreTrainedModel"] = None,
streamer: Optional["BaseStreamer"] = None,
@@ -685,26 +616,15 @@ class QWenLMHeadModel(QWenPreTrainedModel):
self._validate_model_class()
generation_config = copy.deepcopy(generation_config)
model_kwargs = generation_config.update(
**kwargs
) # All unused kwargs must be model kwargs
model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs
generation_config.validate()
self._validate_model_kwargs(model_kwargs.copy())
# 2. Set generation parameters if not already defined
logits_processor = (
logits_processor if logits_processor is not None else LogitsProcessorList()
)
stopping_criteria = (
stopping_criteria
if stopping_criteria is not None
else StoppingCriteriaList()
)
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
if (
generation_config.pad_token_id is None
and generation_config.eos_token_id is not None
):
if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
if model_kwargs.get("attention_mask", None) is None:
logger.warning(
"The attention mask and the pad token id were not set. As a consequence, you may observe "
@@ -713,9 +633,7 @@ class QWenLMHeadModel(QWenPreTrainedModel):
eos_token_id = generation_config.eos_token_id
if isinstance(eos_token_id, list):
eos_token_id = eos_token_id[0]
logger.warning(
f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation."
)
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
generation_config.pad_token_id = eos_token_id
# 3. Define model inputs
@@ -735,46 +653,27 @@ class QWenLMHeadModel(QWenPreTrainedModel):
else:
model_kwargs["use_cache"] = generation_config.use_cache
accepts_attention_mask = "attention_mask" in set(
inspect.signature(self.forward).parameters.keys()
)
accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
requires_attention_mask = "encoder_outputs" not in model_kwargs
if (
model_kwargs.get("attention_mask", None) is None
and requires_attention_mask
and accepts_attention_mask
):
model_kwargs[
"attention_mask"
] = self._prepare_attention_mask_for_generation(
if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask:
model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
inputs_tensor,
generation_config.pad_token_id,
generation_config.eos_token_id,
)
# 5. Prepare `input_ids` which will be used for auto-regressive generation
input_ids = (
inputs_tensor
if model_input_name == "input_ids"
else model_kwargs.pop("input_ids")
)
input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")
if streamer is not None:
streamer.put(input_ids.cpu())
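# _prepare_attention_mask_for_generation is inherited from transformers; its body is not
# shown in this diff. In the common case it reduces to roughly the helper below, which is
# an assumed simplification rather than this file's code:
def _sketch_default_attention_mask(input_ids, pad_token_id):
    if pad_token_id is None:
        return torch.ones_like(input_ids)         # no pad token: every position is attendable
    return input_ids.ne(pad_token_id).long()      # 1 for real tokens, 0 for padding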
# 6. Prepare `max_length` depending on other stopping criteria.
input_ids_length = input_ids.shape[-1]
has_default_max_length = (
kwargs.get("max_length") is None
and generation_config.max_length is not None
)
generation_config.max_length = (
generation_config.max_new_tokens + input_ids_length
)
self._validate_generated_length(
generation_config, input_ids_length, has_default_max_length
)
has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
generation_config.max_length = generation_config.max_new_tokens + input_ids_length
self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
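# worked example of the length bookkeeping above: with a 12-token prompt and
# max_new_tokens=100, generation_config.max_length = 100 + 12 = 112 total positions.
# note that, as reformatted, max_length is always recomputed from max_new_tokens, so this
# path assumes max_new_tokens has been set on the generation config.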
# 8. prepare distribution pre_processing samplers
logits_processor = self._get_logits_processor(
@@ -835,53 +734,25 @@ class QWenLMHeadModel(QWenPreTrainedModel):
**model_kwargs,
):
# init values
logits_processor = (
logits_processor if logits_processor is not None else LogitsProcessorList()
)
stopping_criteria = (
stopping_criteria
if stopping_criteria is not None
else StoppingCriteriaList()
)
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
logits_warper = (
logits_warper if logits_warper is not None else LogitsProcessorList()
)
pad_token_id = (
pad_token_id
if pad_token_id is not None
else self.generation_config.pad_token_id
)
eos_token_id = (
eos_token_id
if eos_token_id is not None
else self.generation_config.eos_token_id
)
logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
eos_token_id_tensor = (
torch.tensor(eos_token_id).to(input_ids.device)
if eos_token_id is not None
else None
)
output_scores = (
output_scores
if output_scores is not None
else self.generation_config.output_scores
)
eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
output_attentions = (
output_attentions
if output_attentions is not None
else self.generation_config.output_attentions
output_attentions if output_attentions is not None else self.generation_config.output_attentions
)
# init attention / hidden states / scores tuples
scores = None
# keep track of which sequences are already finished
unfinished_sequences = torch.ones(
input_ids.shape[0], dtype=torch.long, device=input_ids.device
)
unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
this_peer_finished = False # used by synced_gpus only
# auto-regressive generation
@@ -890,10 +761,7 @@ class QWenLMHeadModel(QWenPreTrainedModel):
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
# forward pass to get next token
outputs = self(
**model_inputs,
output_attentions=output_attentions
)
outputs = self(**model_inputs, output_attentions=output_attentions)
next_token_logits = outputs.logits[:, -1, :]
@@ -908,27 +776,19 @@ class QWenLMHeadModel(QWenPreTrainedModel):
# finished sentences should have their next token be a padding token
if eos_token_id is not None:
if pad_token_id is None:
raise ValueError(
"If `eos_token_id` is defined, make sure that `pad_token_id` is defined."
)
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (
1 - unfinished_sequences
)
raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
# update generated ids, model inputs, and length for next step
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
if streamer is not None:
streamer.put(next_tokens.cpu())
model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=False
)
model_kwargs = self._update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=False)
# if eos_token was found in one sentence, set sentence to finished
if eos_token_id_tensor is not None:
unfinished_sequences = unfinished_sequences.mul(
next_tokens.tile(eos_token_id_tensor.shape[0], 1)
.ne(eos_token_id_tensor.unsqueeze(1))
.prod(dim=0)
next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
)
# stop when each sentence is finished
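A compact worked example of the bookkeeping above: finished rows keep emitting pad_token_id, and a row is switched off in unfinished_sequences once it produces any EOS id.

import torch

pad_token_id, eos_token_id = 0, 2
eos_token_id_tensor = torch.tensor([eos_token_id])

unfinished_sequences = torch.tensor([1, 0])    # sequence 1 already finished earlier
next_tokens = torch.tensor([2, 5])             # sequence 0 just produced the EOS token

# finished sequences are forced to the pad token
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
# -> tensor([2, 0])

# a sequence is marked finished once its new token matches any EOS id
unfinished_sequences = unfinished_sequences.mul(
    next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
)
# -> tensor([0, 0]); generation can stop once every entry is zero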
@@ -964,11 +824,7 @@ class RotaryEmbedding(torch.nn.Module):
if seqlen > self._seq_len_cached or ntk_alpha != self._ntk_alpha_cached:
base = self.base * ntk_alpha ** (self.dim / (self.dim - 2))
self.inv_freq = 1.0 / (
base
** (
torch.arange(0, self.dim, 2, device=self.inv_freq.device).float()
/ self.dim
)
base ** (torch.arange(0, self.dim, 2, device=self.inv_freq.device).float() / self.dim)
)
self._seq_len_cached = max(2 * seqlen, 16)
self._ntk_alpha_cached = ntk_alpha
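A standalone sketch of the cache refresh above: the rotary base is scaled by ntk_alpha ** (dim / (dim - 2)) and the inverse frequencies are recomputed from it. The cos/sin table construction at the end is an assumed continuation, since the rest of the method is not in this hunk:

import torch

dim, base, ntk_alpha, seqlen = 128, 10000.0, 2.0, 4096

scaled_base = base * ntk_alpha ** (dim / (dim - 2))
inv_freq = 1.0 / (scaled_base ** (torch.arange(0, dim, 2).float() / dim))  # (dim // 2,)

# assumed continuation: positions x frequencies -> cos/sin caches used by apply_rotary_pos_emb
positions = torch.arange(max(2 * seqlen, 16)).float()
freqs = torch.outer(positions, inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
cos_cached, sin_cached = emb.cos(), emb.sin()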
@@ -998,7 +854,6 @@ def _rotate_half(x):
def apply_rotary_pos_emb(t, freqs):
rot_dim = freqs[0].shape[-1]
cos, sin = freqs
t_float = t.float()