Update qwen model.
This commit is contained in:
parent
f6538c1111
commit
255a2ff71c
|
@ -50,12 +50,8 @@ from qwen_generation_utils import (
|
||||||
StopWordsLogitsProcessor,
|
StopWordsLogitsProcessor,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
_CHECKPOINT_FOR_DOC = "qwen"
|
|
||||||
_CONFIG_FOR_DOC = "QWenConfig"
|
|
||||||
|
|
||||||
QWen_PRETRAINED_MODEL_ARCHIVE_LIST = ["qwen-7b"]
|
QWen_PRETRAINED_MODEL_ARCHIVE_LIST = ["qwen-7b"]
|
||||||
|
|
||||||
_ERROR_BAD_CHAT_FORMAT = """\
|
_ERROR_BAD_CHAT_FORMAT = """\
|
||||||
|
@ -71,55 +67,8 @@ Pass argument `stream` to model.chat() is buggy, deprecated, and marked for remo
|
||||||
向model.chat()传入参数stream的用法可能存在Bug,该用法已被废弃,将在未来被移除。请使用model.chat_stream(...)代替model.chat(..., stream=True)。
|
向model.chat()传入参数stream的用法可能存在Bug,该用法已被废弃,将在未来被移除。请使用model.chat_stream(...)代替model.chat(..., stream=True)。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_ERROR_INPUT_CPU_QUERY_WITH_FLASH_ATTN_ACTIVATED = """\
|
|
||||||
We detect you have activated flash attention support, but running model computation on CPU. Please make sure that your input data has been placed on GPU. If you actually want to run CPU computation, please following the readme and set device_map="cpu" to disable flash attention when loading the model (calling AutoModelForCausalLM.from_pretrained).
|
|
||||||
检测到您的模型已激活了flash attention支持,但正在执行CPU运算任务。如使用flash attention,请您确认模型输入已经传到GPU上。如果您确认要执行CPU运算,请您在载入模型(调用AutoModelForCausalLM.from_pretrained)时,按照readme说法,指定device_map="cpu"以禁用flash attention。
|
|
||||||
"""
|
|
||||||
|
|
||||||
apply_rotary_emb_func = None
|
apply_rotary_emb_func = None
|
||||||
rms_norm = None
|
rms_norm = None
|
||||||
flash_attn_unpadded_func = None
|
|
||||||
flash_attn_func = None
|
|
||||||
|
|
||||||
def _import_flash_attn():
|
|
||||||
global apply_rotary_emb_func, rms_norm, flash_attn_unpadded_func, flash_attn_func
|
|
||||||
try:
|
|
||||||
from flash_attn.layers.rotary import apply_rotary_emb_func as __apply_rotary_emb_func
|
|
||||||
apply_rotary_emb_func = __apply_rotary_emb_func
|
|
||||||
except ImportError:
|
|
||||||
logger.warn(
|
|
||||||
"Warning: import flash_attn rotary fail, please install FlashAttention rotary to get higher efficiency "
|
|
||||||
"https://github.com/Dao-AILab/flash-attention/tree/main/csrc/rotary"
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
from flash_attn.ops.rms_norm import rms_norm as __rms_norm
|
|
||||||
rms_norm = __rms_norm
|
|
||||||
except ImportError:
|
|
||||||
logger.warn(
|
|
||||||
"Warning: import flash_attn rms_norm fail, please install FlashAttention layer_norm to get higher efficiency "
|
|
||||||
"https://github.com/Dao-AILab/flash-attention/tree/main/csrc/layer_norm"
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
import flash_attn
|
|
||||||
_flash_attn_func = None
|
|
||||||
if not hasattr(flash_attn, '__version__'):
|
|
||||||
from flash_attn.flash_attn_interface import flash_attn_unpadded_func as __flash_attn_unpadded_func
|
|
||||||
else:
|
|
||||||
if int(flash_attn.__version__.split(".")[0]) >= 2:
|
|
||||||
if int(flash_attn.__version__.split(".")[1]) >= 1:
|
|
||||||
from flash_attn.flash_attn_interface import flash_attn_func as _flash_attn_func
|
|
||||||
from flash_attn.flash_attn_interface import flash_attn_varlen_func as __flash_attn_unpadded_func
|
|
||||||
else:
|
|
||||||
from flash_attn.flash_attn_interface import flash_attn_unpadded_func as __flash_attn_unpadded_func
|
|
||||||
flash_attn_unpadded_func = __flash_attn_unpadded_func
|
|
||||||
flash_attn_func = _flash_attn_func
|
|
||||||
except ImportError:
|
|
||||||
logger.warn(
|
|
||||||
"Warning: import flash_attn fail, please install FlashAttention to get higher efficiency "
|
|
||||||
"https://github.com/Dao-AILab/flash-attention"
|
|
||||||
)
|
|
||||||
|
|
||||||
def quantize_cache_v(fdata, bits, qmax, qmin):
|
def quantize_cache_v(fdata, bits, qmax, qmin):
|
||||||
# b, s, head, h-dim->b, head, s, h-dim
|
# b, s, head, h-dim->b, head, s, h-dim
|
||||||
|
@ -147,104 +96,6 @@ def dequantize_cache_torch(qdata, scale, zero):
|
||||||
data = scale * (qdata - zero)
|
data = scale * (qdata - zero)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
class FlashSelfAttention(torch.nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
causal=False,
|
|
||||||
softmax_scale=None,
|
|
||||||
attention_dropout=0.0,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
assert flash_attn_unpadded_func is not None, (
|
|
||||||
"Please install FlashAttention first, " "e.g., with pip install flash-attn"
|
|
||||||
)
|
|
||||||
assert (
|
|
||||||
rearrange is not None
|
|
||||||
), "Please install einops first, e.g., with pip install einops"
|
|
||||||
self.causal = causal
|
|
||||||
self.softmax_scale = softmax_scale
|
|
||||||
self.dropout_p = attention_dropout
|
|
||||||
|
|
||||||
def unpad_input(self, hidden_states, attention_mask):
|
|
||||||
valid_mask = attention_mask.squeeze(1).squeeze(1).eq(0)
|
|
||||||
seqlens_in_batch = valid_mask.sum(dim=-1, dtype=torch.int32)
|
|
||||||
indices = torch.nonzero(valid_mask.flatten(), as_tuple=False).flatten()
|
|
||||||
max_seqlen_in_batch = seqlens_in_batch.max().item()
|
|
||||||
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
|
|
||||||
hidden_states = hidden_states[indices]
|
|
||||||
return hidden_states, indices, cu_seqlens, max_seqlen_in_batch
|
|
||||||
|
|
||||||
def pad_input(self, hidden_states, indices, batch, seqlen):
|
|
||||||
output = torch.zeros(batch * seqlen, *hidden_states.shape[1:], device=hidden_states.device,
|
|
||||||
dtype=hidden_states.dtype)
|
|
||||||
output[indices] = hidden_states
|
|
||||||
return rearrange(output, '(b s) ... -> b s ...', b=batch)
|
|
||||||
|
|
||||||
def forward(self, q, k, v, attention_mask=None):
|
|
||||||
assert all((i.dtype in [torch.float16, torch.bfloat16] for i in (q, k, v)))
|
|
||||||
assert all((i.is_cuda for i in (q, k, v)))
|
|
||||||
batch_size, seqlen_q = q.shape[0], q.shape[1]
|
|
||||||
seqlen_k = k.shape[1]
|
|
||||||
seqlen_out = seqlen_q
|
|
||||||
|
|
||||||
if flash_attn_func is not None and batch_size == 1:
|
|
||||||
dropout_p = self.dropout_p if self.training else 0
|
|
||||||
output = flash_attn_func(q, k, v, dropout_p, softmax_scale=self.softmax_scale, causal=self.causal)
|
|
||||||
return output
|
|
||||||
|
|
||||||
q, k, v = [rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]]
|
|
||||||
cu_seqlens_q = torch.arange(
|
|
||||||
0,
|
|
||||||
(batch_size + 1) * seqlen_q,
|
|
||||||
step=seqlen_q,
|
|
||||||
dtype=torch.int32,
|
|
||||||
device=q.device,
|
|
||||||
)
|
|
||||||
|
|
||||||
if batch_size > 1 and attention_mask is not None:
|
|
||||||
k, indices_k, cu_seqlens_k, seqlen_k = self.unpad_input(k, attention_mask)
|
|
||||||
if q.size(0) == v.size(0):
|
|
||||||
q = q[indices_k]
|
|
||||||
cu_seqlens_q = cu_seqlens_k
|
|
||||||
seqlen_q = seqlen_k
|
|
||||||
v = v[indices_k]
|
|
||||||
else:
|
|
||||||
cu_seqlens_k = torch.arange(
|
|
||||||
0,
|
|
||||||
(batch_size + 1) * seqlen_k,
|
|
||||||
step=seqlen_k,
|
|
||||||
dtype=torch.int32,
|
|
||||||
device=q.device,
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.training:
|
|
||||||
assert seqlen_k == seqlen_q
|
|
||||||
is_causal = self.causal
|
|
||||||
dropout_p = self.dropout_p
|
|
||||||
else:
|
|
||||||
is_causal = seqlen_q == seqlen_k
|
|
||||||
dropout_p = 0
|
|
||||||
|
|
||||||
output = flash_attn_unpadded_func(
|
|
||||||
q,
|
|
||||||
k,
|
|
||||||
v,
|
|
||||||
cu_seqlens_q,
|
|
||||||
cu_seqlens_k,
|
|
||||||
seqlen_q,
|
|
||||||
seqlen_k,
|
|
||||||
dropout_p,
|
|
||||||
softmax_scale=self.softmax_scale,
|
|
||||||
causal=is_causal,
|
|
||||||
)
|
|
||||||
if batch_size > 1 and attention_mask is not None and seqlen_q == seqlen_k:
|
|
||||||
output = self.pad_input(output, indices_k, batch_size, seqlen_out)
|
|
||||||
else:
|
|
||||||
new_shape = (batch_size, output.shape[0] // batch_size) + output.shape[1:]
|
|
||||||
output = output.view(new_shape)
|
|
||||||
return output
|
|
||||||
|
|
||||||
|
|
||||||
class QWenAttention(nn.Module):
|
class QWenAttention(nn.Module):
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -257,7 +108,6 @@ class QWenAttention(nn.Module):
|
||||||
self.num_heads = config.num_attention_heads
|
self.num_heads = config.num_attention_heads
|
||||||
self.head_dim = self.hidden_size // self.num_heads
|
self.head_dim = self.hidden_size // self.num_heads
|
||||||
|
|
||||||
self.use_flash_attn = config.use_flash_attn
|
|
||||||
self.scale_attn_weights = True
|
self.scale_attn_weights = True
|
||||||
|
|
||||||
self.projection_size = config.kv_channels * config.num_attention_heads
|
self.projection_size = config.kv_channels * config.num_attention_heads
|
||||||
|
@ -274,14 +124,7 @@ class QWenAttention(nn.Module):
|
||||||
)
|
)
|
||||||
|
|
||||||
self.is_fp32 = not (config.bf16 or config.fp16)
|
self.is_fp32 = not (config.bf16 or config.fp16)
|
||||||
if (
|
|
||||||
self.use_flash_attn
|
|
||||||
and flash_attn_unpadded_func is not None
|
|
||||||
and not self.is_fp32
|
|
||||||
):
|
|
||||||
self.core_attention_flash = FlashSelfAttention(
|
|
||||||
causal=True, attention_dropout=config.attn_dropout_prob
|
|
||||||
)
|
|
||||||
self.bf16 = config.bf16
|
self.bf16 = config.bf16
|
||||||
|
|
||||||
self.use_dynamic_ntk = config.use_dynamic_ntk
|
self.use_dynamic_ntk = config.use_dynamic_ntk
|
||||||
|
@ -490,15 +333,6 @@ class QWenAttention(nn.Module):
|
||||||
logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :].type_as(query)
|
logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :].type_as(query)
|
||||||
query = query * logn_tensor.expand_as(query)
|
query = query * logn_tensor.expand_as(query)
|
||||||
|
|
||||||
if (
|
|
||||||
self.use_flash_attn
|
|
||||||
and flash_attn_unpadded_func is not None
|
|
||||||
and not self.is_fp32
|
|
||||||
and query.is_cuda
|
|
||||||
):
|
|
||||||
q, k, v = query, key, value
|
|
||||||
attn_output = self.core_attention_flash(q, k, v, attention_mask=attention_mask)
|
|
||||||
else:
|
|
||||||
key_size = key[0].size(2) if self.use_cache_quantization else key.size(1)
|
key_size = key[0].size(2) if self.use_cache_quantization else key.size(1)
|
||||||
if query.size(1) == key_size:
|
if query.size(1) == key_size:
|
||||||
causal_mask = torch.tril(
|
causal_mask = torch.tril(
|
||||||
|
@ -510,14 +344,6 @@ class QWenAttention(nn.Module):
|
||||||
if not self.use_cache_quantization:
|
if not self.use_cache_quantization:
|
||||||
key = key.permute(0, 2, 1, 3)
|
key = key.permute(0, 2, 1, 3)
|
||||||
value = value.permute(0, 2, 1, 3)
|
value = value.permute(0, 2, 1, 3)
|
||||||
if (
|
|
||||||
causal_mask is None
|
|
||||||
and self.use_flash_attn
|
|
||||||
and flash_attn_unpadded_func is not None
|
|
||||||
and not self.is_fp32
|
|
||||||
and not query.is_cuda
|
|
||||||
):
|
|
||||||
raise Exception(_ERROR_INPUT_CPU_QUERY_WITH_FLASH_ATTN_ACTIVATED)
|
|
||||||
|
|
||||||
if not self.use_cache_quantization and SUPPORT_TORCH2:
|
if not self.use_cache_quantization and SUPPORT_TORCH2:
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
|
@ -544,13 +370,7 @@ class QWenAttention(nn.Module):
|
||||||
|
|
||||||
outputs = (attn_output, present)
|
outputs = (attn_output, present)
|
||||||
if output_attentions:
|
if output_attentions:
|
||||||
if (
|
if not self.use_cache_quantization and SUPPORT_TORCH2:
|
||||||
self.use_flash_attn
|
|
||||||
and flash_attn_unpadded_func is not None
|
|
||||||
and not self.is_fp32
|
|
||||||
):
|
|
||||||
raise ValueError("Cannot output attentions while using flash-attn")
|
|
||||||
elif not self.use_cache_quantization and SUPPORT_TORCH2:
|
|
||||||
raise ValueError("Cannot output attentions while using scaled_dot_product_attention")
|
raise ValueError("Cannot output attentions while using scaled_dot_product_attention")
|
||||||
else:
|
else:
|
||||||
outputs += (attn_weight,)
|
outputs += (attn_weight,)
|
||||||
|
@ -711,7 +531,6 @@ class QWenModel(QWenPreTrainedModel):
|
||||||
)
|
)
|
||||||
self.rotary_emb = RotaryEmbedding(dim, base=config.rotary_emb_base)
|
self.rotary_emb = RotaryEmbedding(dim, base=config.rotary_emb_base)
|
||||||
|
|
||||||
self.use_flash_attn = config.use_flash_attn
|
|
||||||
self.is_fp32 = not (config.bf16 or config.fp16)
|
self.is_fp32 = not (config.bf16 or config.fp16)
|
||||||
|
|
||||||
self.h = nn.ModuleList(
|
self.h = nn.ModuleList(
|
||||||
|
@ -967,18 +786,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
|
||||||
elif SUPPORT_FP16:
|
elif SUPPORT_FP16:
|
||||||
logger.warn("Your device support faster inference by passing fp16=True in \"AutoModelForCausalLM.from_pretrained\".")
|
logger.warn("Your device support faster inference by passing fp16=True in \"AutoModelForCausalLM.from_pretrained\".")
|
||||||
|
|
||||||
if config.use_flash_attn == "auto":
|
|
||||||
if config.bf16 or config.fp16:
|
|
||||||
logger.warn("Try importing flash-attention for faster inference...")
|
|
||||||
config.use_flash_attn = True
|
|
||||||
else:
|
|
||||||
config.use_flash_attn = False
|
|
||||||
if config.use_flash_attn and config.fp32:
|
|
||||||
logger.warn("Flash attention will be disabled because it does NOT support fp32.")
|
|
||||||
|
|
||||||
if config.use_flash_attn:
|
|
||||||
_import_flash_attn()
|
|
||||||
|
|
||||||
self.transformer = QWenModel(config)
|
self.transformer = QWenModel(config)
|
||||||
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||||
|
|
||||||
|
@ -1073,19 +880,12 @@ class QWenLMHeadModel(QWenPreTrainedModel):
|
||||||
)
|
)
|
||||||
|
|
||||||
# shift_labels = torch.ones([1,19]).to(lm_logits.device).to(torch.int64)
|
# shift_labels = torch.ones([1,19]).to(lm_logits.device).to(torch.int64)
|
||||||
|
|
||||||
# shift_logits = lm_logits[..., :-1, :].contiguous()
|
# shift_logits = lm_logits[..., :-1, :].contiguous()
|
||||||
# loss_fct = CrossEntropyLoss()
|
# loss_fct = CrossEntropyLoss()
|
||||||
# loss = loss_fct(
|
# loss = loss_fct(
|
||||||
# shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
|
# shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
|
||||||
# )
|
# )
|
||||||
|
|
||||||
# optimizer = torch.optim.Adam(self.parameters(), lr=2e-5)
|
|
||||||
# # optimizer = torch.optim.SGD(self.parameters(),lr=0.001)
|
|
||||||
# # pa = self.transformer.parameters()
|
|
||||||
|
|
||||||
# loss.backward()
|
# loss.backward()
|
||||||
# # optimizer.step()
|
|
||||||
|
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
|
|
Loading…
Reference in New Issue