Update qwen model.

This commit is contained in:
Colin 2024-01-07 16:22:41 +08:00
parent f6538c1111
commit 255a2ff71c
1 changed files with 28 additions and 228 deletions

View File

@ -50,12 +50,8 @@ from qwen_generation_utils import (
StopWordsLogitsProcessor,
)
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "qwen"
_CONFIG_FOR_DOC = "QWenConfig"
QWen_PRETRAINED_MODEL_ARCHIVE_LIST = ["qwen-7b"]
_ERROR_BAD_CHAT_FORMAT = """\
@ -71,55 +67,8 @@ Pass argument `stream` to model.chat() is buggy, deprecated, and marked for remo
向model.chat()传入参数stream的用法可能存在Bug该用法已被废弃将在未来被移除请使用model.chat_stream(...)代替model.chat(..., stream=True)
"""
_ERROR_INPUT_CPU_QUERY_WITH_FLASH_ATTN_ACTIVATED = """\
We detect you have activated flash attention support, but running model computation on CPU. Please make sure that your input data has been placed on GPU. If you actually want to run CPU computation, please following the readme and set device_map="cpu" to disable flash attention when loading the model (calling AutoModelForCausalLM.from_pretrained).
检测到您的模型已激活了flash attention支持但正在执行CPU运算任务如使用flash attention请您确认模型输入已经传到GPU上如果您确认要执行CPU运算请您在载入模型调用AutoModelForCausalLM.from_pretrained按照readme说法指定device_map="cpu"以禁用flash attention
"""
apply_rotary_emb_func = None
rms_norm = None
flash_attn_unpadded_func = None
flash_attn_func = None
def _import_flash_attn():
global apply_rotary_emb_func, rms_norm, flash_attn_unpadded_func, flash_attn_func
try:
from flash_attn.layers.rotary import apply_rotary_emb_func as __apply_rotary_emb_func
apply_rotary_emb_func = __apply_rotary_emb_func
except ImportError:
logger.warn(
"Warning: import flash_attn rotary fail, please install FlashAttention rotary to get higher efficiency "
"https://github.com/Dao-AILab/flash-attention/tree/main/csrc/rotary"
)
try:
from flash_attn.ops.rms_norm import rms_norm as __rms_norm
rms_norm = __rms_norm
except ImportError:
logger.warn(
"Warning: import flash_attn rms_norm fail, please install FlashAttention layer_norm to get higher efficiency "
"https://github.com/Dao-AILab/flash-attention/tree/main/csrc/layer_norm"
)
try:
import flash_attn
_flash_attn_func = None
if not hasattr(flash_attn, '__version__'):
from flash_attn.flash_attn_interface import flash_attn_unpadded_func as __flash_attn_unpadded_func
else:
if int(flash_attn.__version__.split(".")[0]) >= 2:
if int(flash_attn.__version__.split(".")[1]) >= 1:
from flash_attn.flash_attn_interface import flash_attn_func as _flash_attn_func
from flash_attn.flash_attn_interface import flash_attn_varlen_func as __flash_attn_unpadded_func
else:
from flash_attn.flash_attn_interface import flash_attn_unpadded_func as __flash_attn_unpadded_func
flash_attn_unpadded_func = __flash_attn_unpadded_func
flash_attn_func = _flash_attn_func
except ImportError:
logger.warn(
"Warning: import flash_attn fail, please install FlashAttention to get higher efficiency "
"https://github.com/Dao-AILab/flash-attention"
)
def quantize_cache_v(fdata, bits, qmax, qmin):
# b, s, head, h-dim->b, head, s, h-dim
@ -147,104 +96,6 @@ def dequantize_cache_torch(qdata, scale, zero):
data = scale * (qdata - zero)
return data
class FlashSelfAttention(torch.nn.Module):
def __init__(
self,
causal=False,
softmax_scale=None,
attention_dropout=0.0,
):
super().__init__()
assert flash_attn_unpadded_func is not None, (
"Please install FlashAttention first, " "e.g., with pip install flash-attn"
)
assert (
rearrange is not None
), "Please install einops first, e.g., with pip install einops"
self.causal = causal
self.softmax_scale = softmax_scale
self.dropout_p = attention_dropout
def unpad_input(self, hidden_states, attention_mask):
valid_mask = attention_mask.squeeze(1).squeeze(1).eq(0)
seqlens_in_batch = valid_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(valid_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
hidden_states = hidden_states[indices]
return hidden_states, indices, cu_seqlens, max_seqlen_in_batch
def pad_input(self, hidden_states, indices, batch, seqlen):
output = torch.zeros(batch * seqlen, *hidden_states.shape[1:], device=hidden_states.device,
dtype=hidden_states.dtype)
output[indices] = hidden_states
return rearrange(output, '(b s) ... -> b s ...', b=batch)
def forward(self, q, k, v, attention_mask=None):
assert all((i.dtype in [torch.float16, torch.bfloat16] for i in (q, k, v)))
assert all((i.is_cuda for i in (q, k, v)))
batch_size, seqlen_q = q.shape[0], q.shape[1]
seqlen_k = k.shape[1]
seqlen_out = seqlen_q
if flash_attn_func is not None and batch_size == 1:
dropout_p = self.dropout_p if self.training else 0
output = flash_attn_func(q, k, v, dropout_p, softmax_scale=self.softmax_scale, causal=self.causal)
return output
q, k, v = [rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]]
cu_seqlens_q = torch.arange(
0,
(batch_size + 1) * seqlen_q,
step=seqlen_q,
dtype=torch.int32,
device=q.device,
)
if batch_size > 1 and attention_mask is not None:
k, indices_k, cu_seqlens_k, seqlen_k = self.unpad_input(k, attention_mask)
if q.size(0) == v.size(0):
q = q[indices_k]
cu_seqlens_q = cu_seqlens_k
seqlen_q = seqlen_k
v = v[indices_k]
else:
cu_seqlens_k = torch.arange(
0,
(batch_size + 1) * seqlen_k,
step=seqlen_k,
dtype=torch.int32,
device=q.device,
)
if self.training:
assert seqlen_k == seqlen_q
is_causal = self.causal
dropout_p = self.dropout_p
else:
is_causal = seqlen_q == seqlen_k
dropout_p = 0
output = flash_attn_unpadded_func(
q,
k,
v,
cu_seqlens_q,
cu_seqlens_k,
seqlen_q,
seqlen_k,
dropout_p,
softmax_scale=self.softmax_scale,
causal=is_causal,
)
if batch_size > 1 and attention_mask is not None and seqlen_q == seqlen_k:
output = self.pad_input(output, indices_k, batch_size, seqlen_out)
else:
new_shape = (batch_size, output.shape[0] // batch_size) + output.shape[1:]
output = output.view(new_shape)
return output
class QWenAttention(nn.Module):
def __init__(self, config):
super().__init__()
@ -257,7 +108,6 @@ class QWenAttention(nn.Module):
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
self.use_flash_attn = config.use_flash_attn
self.scale_attn_weights = True
self.projection_size = config.kv_channels * config.num_attention_heads
@ -274,14 +124,7 @@ class QWenAttention(nn.Module):
)
self.is_fp32 = not (config.bf16 or config.fp16)
if (
self.use_flash_attn
and flash_attn_unpadded_func is not None
and not self.is_fp32
):
self.core_attention_flash = FlashSelfAttention(
causal=True, attention_dropout=config.attn_dropout_prob
)
self.bf16 = config.bf16
self.use_dynamic_ntk = config.use_dynamic_ntk
@ -490,52 +333,35 @@ class QWenAttention(nn.Module):
logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :].type_as(query)
query = query * logn_tensor.expand_as(query)
if (
self.use_flash_attn
and flash_attn_unpadded_func is not None
and not self.is_fp32
and query.is_cuda
):
q, k, v = query, key, value
attn_output = self.core_attention_flash(q, k, v, attention_mask=attention_mask)
key_size = key[0].size(2) if self.use_cache_quantization else key.size(1)
if query.size(1) == key_size:
causal_mask = torch.tril(
torch.ones((key_size, key_size), dtype=torch.bool, device=query.device)
).view(1, 1, key_size, key_size)
else:
key_size = key[0].size(2) if self.use_cache_quantization else key.size(1)
if query.size(1) == key_size:
causal_mask = torch.tril(
torch.ones((key_size, key_size), dtype=torch.bool, device=query.device)
).view(1, 1, key_size, key_size)
else:
causal_mask = None
query = query.permute(0, 2, 1, 3)
if not self.use_cache_quantization:
key = key.permute(0, 2, 1, 3)
value = value.permute(0, 2, 1, 3)
if (
causal_mask is None
and self.use_flash_attn
and flash_attn_unpadded_func is not None
and not self.is_fp32
and not query.is_cuda
):
raise Exception(_ERROR_INPUT_CPU_QUERY_WITH_FLASH_ATTN_ACTIVATED)
causal_mask = None
query = query.permute(0, 2, 1, 3)
if not self.use_cache_quantization:
key = key.permute(0, 2, 1, 3)
value = value.permute(0, 2, 1, 3)
if not self.use_cache_quantization and SUPPORT_TORCH2:
if attention_mask is not None:
attention_mask = attention_mask.expand(
-1, -1, causal_mask.size(2), -1
)
if causal_mask is not None:
attention_mask = attention_mask.masked_fill(~causal_mask, torch.finfo(query.dtype).min)
else:
attention_mask = causal_mask
attn_output = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask
).transpose(1, 2)
attn_weight = None
else:
attn_output, attn_weight = self._attn(
query, key, value, causal_mask, attention_mask, head_mask
if not self.use_cache_quantization and SUPPORT_TORCH2:
if attention_mask is not None:
attention_mask = attention_mask.expand(
-1, -1, causal_mask.size(2), -1
)
if causal_mask is not None:
attention_mask = attention_mask.masked_fill(~causal_mask, torch.finfo(query.dtype).min)
else:
attention_mask = causal_mask
attn_output = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask
).transpose(1, 2)
attn_weight = None
else:
attn_output, attn_weight = self._attn(
query, key, value, causal_mask, attention_mask, head_mask
)
context_layer = self._merge_heads(
attn_output, self.num_heads, self.head_dim
)
@ -544,13 +370,7 @@ class QWenAttention(nn.Module):
outputs = (attn_output, present)
if output_attentions:
if (
self.use_flash_attn
and flash_attn_unpadded_func is not None
and not self.is_fp32
):
raise ValueError("Cannot output attentions while using flash-attn")
elif not self.use_cache_quantization and SUPPORT_TORCH2:
if not self.use_cache_quantization and SUPPORT_TORCH2:
raise ValueError("Cannot output attentions while using scaled_dot_product_attention")
else:
outputs += (attn_weight,)
@ -711,7 +531,6 @@ class QWenModel(QWenPreTrainedModel):
)
self.rotary_emb = RotaryEmbedding(dim, base=config.rotary_emb_base)
self.use_flash_attn = config.use_flash_attn
self.is_fp32 = not (config.bf16 or config.fp16)
self.h = nn.ModuleList(
@ -967,18 +786,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
elif SUPPORT_FP16:
logger.warn("Your device support faster inference by passing fp16=True in \"AutoModelForCausalLM.from_pretrained\".")
if config.use_flash_attn == "auto":
if config.bf16 or config.fp16:
logger.warn("Try importing flash-attention for faster inference...")
config.use_flash_attn = True
else:
config.use_flash_attn = False
if config.use_flash_attn and config.fp32:
logger.warn("Flash attention will be disabled because it does NOT support fp32.")
if config.use_flash_attn:
_import_flash_attn()
self.transformer = QWenModel(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
@ -1073,19 +880,12 @@ class QWenLMHeadModel(QWenPreTrainedModel):
)
# shift_labels = torch.ones([1,19]).to(lm_logits.device).to(torch.int64)
# shift_logits = lm_logits[..., :-1, :].contiguous()
# loss_fct = CrossEntropyLoss()
# loss = loss_fct(
# shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
# )
# optimizer = torch.optim.Adam(self.parameters(), lr=2e-5)
# # optimizer = torch.optim.SGD(self.parameters(),lr=0.001)
# # pa = self.transformer.parameters()
# loss.backward()
# # optimizer.step()
if not return_dict: