Format qwen model.

commit 611396b656
parent 255a2ff71c
@@ -38,7 +38,9 @@ from torch import nn
 SUPPORT_CUDA = torch.cuda.is_available()
 SUPPORT_BF16 = SUPPORT_CUDA and torch.cuda.is_bf16_supported()
 SUPPORT_FP16 = SUPPORT_CUDA and torch.cuda.get_device_capability(0)[0] >= 7
-SUPPORT_TORCH2 = hasattr(torch, '__version__') and int(torch.__version__.split(".")[0]) >= 2
+SUPPORT_TORCH2 = (
+    hasattr(torch, "__version__") and int(torch.__version__.split(".")[0]) >= 2
+)
 
 
 from configuration_qwen import QWenConfig
@@ -70,6 +72,7 @@ Pass argument `stream` to model.chat() is buggy, deprecated, and marked for remo
 apply_rotary_emb_func = None
 rms_norm = None
 
+
 def quantize_cache_v(fdata, bits, qmax, qmin):
     # b, s, head, h-dim->b, head, s, h-dim
     qtype = torch.uint8
@@ -85,17 +88,19 @@ def quantize_cache_v(fdata, bits, qmax, qmin):
         qmin = qmin.to(device)
     scale = (fmax - fmin) / (qmax - qmin)
     zero = qmin - fmin / scale
-    scale = scale.unsqueeze(-1).repeat(1,1,shape[2],1).contiguous()
-    zero = zero.unsqueeze(-1).repeat(1,1,shape[2],1).contiguous()
+    scale = scale.unsqueeze(-1).repeat(1, 1, shape[2], 1).contiguous()
+    zero = zero.unsqueeze(-1).repeat(1, 1, shape[2], 1).contiguous()
     # Quantize
     res_data = fdata / scale + zero
     qdata = torch.clamp(res_data, qmin, qmax).to(qtype)
     return qdata.contiguous(), scale, zero
 
+
 def dequantize_cache_torch(qdata, scale, zero):
     data = scale * (qdata - zero)
     return data
 
+
 class QWenAttention(nn.Module):
     def __init__(self, config):
         super().__init__()
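The two helpers reformatted above implement plain affine uint8 quantization of the KV cache: scale = (fmax - fmin) / (qmax - qmin), zero = qmin - fmin / scale, quantize with clamp(fdata / scale + zero, qmin, qmax), and dequantize with scale * (qdata - zero). A minimal self-contained sketch of that round-trip, using per-tensor statistics and made-up shapes for illustration (the real code keeps per-slice scale/zero tensors and can route the matmuls through dedicated CUDA kernels):

import torch

# Affine uint8 quantization, per-tensor for brevity.
def quantize(fdata, qmin=0.0, qmax=255.0):
    fmin, fmax = fdata.min(), fdata.max()
    scale = (fmax - fmin) / (qmax - qmin)
    zero = qmin - fmin / scale
    qdata = torch.clamp(fdata / scale + zero, qmin, qmax).to(torch.uint8)
    return qdata, scale, zero

def dequantize(qdata, scale, zero):
    return scale * (qdata - zero)

x = torch.randn(2, 8, 16, 32)   # (batch, head, seq, head_dim), toy shape
q, s, z = quantize(x)
x_hat = dequantize(q, s, z)     # approximate reconstruction of x
print((x - x_hat).abs().max())  # small quantization error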
@@ -138,12 +143,20 @@ class QWenAttention(nn.Module):
         self.register_buffer("logn_tensor", logn_tensor, persistent=False)
 
         self.attn_dropout = nn.Dropout(config.attn_dropout_prob)
-        self.softmax_in_fp32 = config.softmax_in_fp32 if hasattr(config, 'softmax_in_fp32') else False
-        self.use_cache_quantization = config.use_cache_quantization if hasattr(config, 'use_cache_quantization') else False
-        self.use_cache_kernel = config.use_cache_kernel if hasattr(config,'use_cache_kernel') else False
+        self.softmax_in_fp32 = (
+            config.softmax_in_fp32 if hasattr(config, "softmax_in_fp32") else False
+        )
+        self.use_cache_quantization = (
+            config.use_cache_quantization
+            if hasattr(config, "use_cache_quantization")
+            else False
+        )
+        self.use_cache_kernel = (
+            config.use_cache_kernel if hasattr(config, "use_cache_kernel") else False
+        )
         cache_dtype = torch.float
         if self.bf16:
-            cache_dtype=torch.bfloat16
+            cache_dtype = torch.bfloat16
         elif config.fp16:
             cache_dtype = torch.float16
         self.cache_qmax = torch.tensor(torch.iinfo(torch.uint8).max, dtype=cache_dtype)
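The formatter keeps the repeated "value if hasattr(config, ...) else False" defaults intact; functionally they match the shorter getattr-with-default idiom, shown here on a stand-in config object (illustration only, not a change proposed by this commit):

class _Cfg:  # hypothetical stand-in for QWenConfig
    softmax_in_fp32 = True

cfg = _Cfg()
# getattr with a default collapses the hasattr/else-False pattern into one call.
softmax_in_fp32 = getattr(cfg, "softmax_in_fp32", False)                 # True
use_cache_quantization = getattr(cfg, "use_cache_quantization", False)   # False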
@@ -152,19 +165,25 @@ class QWenAttention(nn.Module):
         if config.use_cache_quantization and config.use_cache_kernel:
             # pre check if the support files existing
             module_root = pathlib.Path(__file__).parent
-            src_files = ("cache_autogptq_cuda_256.cpp", "cache_autogptq_cuda_kernel_256.cu")
-            if any(not (module_root/src).is_file() for src in src_files):
+            src_files = (
+                "cache_autogptq_cuda_256.cpp",
+                "cache_autogptq_cuda_kernel_256.cu",
+            )
+            if any(not (module_root / src).is_file() for src in src_files):
                 warnings.warn("KV cache kernel source files (.cpp and .cu) not found.")
                 self.cache_kernels = None
             else:
                 try:
                     from .cpp_kernels import cache_autogptq_cuda_256
+
                     self.cache_kernels = cache_autogptq_cuda_256
                 except ImportError:
                     warnings.warn("Failed to import KV cache kernels.")
                     self.cache_kernels = None
 
-    def _attn(self, query, key, value, causal_mask=None, attention_mask=None, head_mask=None):
+    def _attn(
+        self, query, key, value, causal_mask=None, attention_mask=None, head_mask=None
+    ):
         device = query.device
         if self.use_cache_quantization:
             qk, qk_scale, qk_zero = key
@@ -172,11 +191,18 @@ class QWenAttention(nn.Module):
                 shape = query.shape[:-1] + (qk.shape[-2],)
                 attn_weights = torch.zeros(shape, dtype=torch.float16, device=device)
                 self.cache_kernels.vecquant8matmul_batched_faster_old(
-                    query.contiguous() if query.dtype == torch.float16 else query.to(torch.float16).contiguous(),
+                    query.contiguous()
+                    if query.dtype == torch.float16
+                    else query.to(torch.float16).contiguous(),
                     qk.transpose(-1, -2).contiguous(),
                     attn_weights,
-                    qk_scale.contiguous() if qk_scale.dtype == torch.float16 else qk_scale.to(torch.float16).contiguous(),
-                    qk_zero.contiguous()if qk_zero.dtype == torch.float16 else qk_zero.to(torch.float16).contiguous())
+                    qk_scale.contiguous()
+                    if qk_scale.dtype == torch.float16
+                    else qk_scale.to(torch.float16).contiguous(),
+                    qk_zero.contiguous()
+                    if qk_zero.dtype == torch.float16
+                    else qk_zero.to(torch.float16).contiguous(),
+                )
                 # attn_weights = attn_weights.to(query.dtype).contiguous()
             else:
                 key = dequantize_cache_torch(qk, qk_scale, qk_zero)
@@ -189,7 +215,7 @@ class QWenAttention(nn.Module):
                 size_temp = value[0].size(-1)
             else:
                 size_temp = value.size(-1)
-            attn_weights = attn_weights / (size_temp ** 0.5)
+            attn_weights = attn_weights / (size_temp**0.5)
 
         mask_value = torch.finfo(attn_weights.dtype).min
         if causal_mask is not None:
@@ -217,11 +243,18 @@ class QWenAttention(nn.Module):
                 shape = attn_weights.shape[:-1] + (query.shape[-1],)
                 attn_output = torch.zeros(shape, dtype=torch.float16, device=device)
                 self.cache_kernels.vecquant8matmul_batched_column_compression_faster_old(
-                    attn_weights.contiguous() if attn_weights.dtype == torch.float16 else attn_weights.to(torch.float16).contiguous(),
+                    attn_weights.contiguous()
+                    if attn_weights.dtype == torch.float16
+                    else attn_weights.to(torch.float16).contiguous(),
                     qv.contiguous(),  # dtype: int32
                     attn_output,
-                    qv_scale.contiguous() if qv_scale.dtype == torch.float16 else qv_scale.to(torch.float16).contiguous(),
-                    qv_zero.contiguous() if qv_zero.dtype == torch.float16 else qv_zero.to(torch.float16).contiguous())
+                    qv_scale.contiguous()
+                    if qv_scale.dtype == torch.float16
+                    else qv_scale.to(torch.float16).contiguous(),
+                    qv_zero.contiguous()
+                    if qv_zero.dtype == torch.float16
+                    else qv_zero.to(torch.float16).contiguous(),
+                )
                 if attn_output.dtype != query.dtype:
                     attn_output = attn_output.to(query.dtype)
                     attn_weights = attn_weights.to(query.dtype)
@@ -283,21 +316,26 @@ class QWenAttention(nn.Module):
                     rotary_pos_emb = (rotary_pos_emb,) * 2
                     q_pos_emb, k_pos_emb = rotary_pos_emb
                     # Slice the pos emb for current inference
-                    query_list += [apply_rotary_pos_emb(query[i:i+1, :, :], q_pos_emb)]
-                    key_list += [apply_rotary_pos_emb(key[i:i+1, :, :], k_pos_emb)]
+                    query_list += [
+                        apply_rotary_pos_emb(query[i : i + 1, :, :], q_pos_emb)
+                    ]
+                    key_list += [apply_rotary_pos_emb(key[i : i + 1, :, :], k_pos_emb)]
                 query = torch.cat(query_list, dim=0)
                 key = torch.cat(key_list, dim=0)
 
         if self.use_cache_quantization:
-            key = quantize_cache_v(key.permute(0, 2, 1, 3),
-                                       bits=8,
-                                       qmin=self.cache_qmin,
-                                       qmax=self.cache_qmax)
-            value = quantize_cache_v(value.permute(0, 2, 1, 3),
-                                         bits=8,
-                                         qmin=self.cache_qmin,
-                                         qmax=self.cache_qmax)
-
+            key = quantize_cache_v(
+                key.permute(0, 2, 1, 3),
+                bits=8,
+                qmin=self.cache_qmin,
+                qmax=self.cache_qmax,
+            )
+            value = quantize_cache_v(
+                value.permute(0, 2, 1, 3),
+                bits=8,
+                qmin=self.cache_qmin,
+                qmax=self.cache_qmax,
+            )
 
         if layer_past is not None:
             past_key, past_value = layer_past[0], layer_past[1]
@@ -305,12 +343,16 @@ class QWenAttention(nn.Module):
                 # use_cache_quantization:
                 # present=((q_key,key_scale,key_zero_point),
                 #          (q_value,value_scale,value_zero_point))
-                key = (torch.cat((past_key[0], key[0]), dim=2),
-                       torch.cat((past_key[1], key[1]), dim=2),
-                       torch.cat((past_key[2], key[2]), dim=2))
-                value = (torch.cat((past_value[0], value[0]), dim=2),
-                         torch.cat((past_value[1], value[1]), dim=2),
-                         torch.cat((past_value[2], value[2]), dim=2))
+                key = (
+                    torch.cat((past_key[0], key[0]), dim=2),
+                    torch.cat((past_key[1], key[1]), dim=2),
+                    torch.cat((past_key[2], key[2]), dim=2),
+                )
+                value = (
+                    torch.cat((past_value[0], value[0]), dim=2),
+                    torch.cat((past_value[1], value[1]), dim=2),
+                    torch.cat((past_value[2], value[2]), dim=2),
+                )
             else:
                 # not use_cache_quantization:
                 # present=(key,value)
@@ -347,11 +389,11 @@ class QWenAttention(nn.Module):
 
         if not self.use_cache_quantization and SUPPORT_TORCH2:
             if attention_mask is not None:
-                attention_mask = attention_mask.expand(
-                    -1, -1, causal_mask.size(2), -1
-                )
+                attention_mask = attention_mask.expand(-1, -1, causal_mask.size(2), -1)
                 if causal_mask is not None:
-                    attention_mask = attention_mask.masked_fill(~causal_mask, torch.finfo(query.dtype).min)
+                    attention_mask = attention_mask.masked_fill(
+                        ~causal_mask, torch.finfo(query.dtype).min
+                    )
             else:
                 attention_mask = causal_mask
             attn_output = F.scaled_dot_product_attention(
@@ -362,16 +404,16 @@ class QWenAttention(nn.Module):
             attn_output, attn_weight = self._attn(
                 query, key, value, causal_mask, attention_mask, head_mask
             )
-        context_layer = self._merge_heads(
-            attn_output, self.num_heads, self.head_dim
-        )
+        context_layer = self._merge_heads(attn_output, self.num_heads, self.head_dim)
 
         attn_output = self.c_proj(context_layer)
 
         outputs = (attn_output, present)
         if output_attentions:
             if not self.use_cache_quantization and SUPPORT_TORCH2:
-                raise ValueError("Cannot output attentions while using scaled_dot_product_attention")
+                raise ValueError(
+                    "Cannot output attentions while using scaled_dot_product_attention"
+                )
             else:
                 outputs += (attn_weight,)
 
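The hunk above reflows the PyTorch 2 fast path: the padding mask is expanded to the causal shape, positions outside the causal window are filled with the dtype minimum, and the combined additive mask is handed to F.scaled_dot_product_attention. A reduced sketch of that masking logic with toy shapes and no real padding (illustration only, not the module's full code path):

import torch
import torch.nn.functional as F

b, h, q_len, k_len, d = 1, 2, 4, 4, 8
query = torch.randn(b, h, q_len, d)
key = torch.randn(b, h, k_len, d)
value = torch.randn(b, h, k_len, d)

# Additive padding mask: 0.0 where attention is allowed (no padding in this toy case).
attention_mask = torch.zeros(b, 1, 1, k_len)
attention_mask = attention_mask.expand(-1, -1, q_len, -1)

# Boolean causal mask: True where a query position may attend to a key position.
causal_mask = torch.tril(torch.ones(q_len, k_len, dtype=torch.bool))[None, None]

attention_mask = attention_mask.masked_fill(~causal_mask, torch.finfo(query.dtype).min)
attn_output = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask)
print(attn_output.shape)  # torch.Size([1, 2, 4, 8])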
@@ -507,7 +549,11 @@ class QWenModel(QWenPreTrainedModel):
         self.vocab_size = config.vocab_size
         self.num_hidden_layers = config.num_hidden_layers
         self.embed_dim = config.hidden_size
-        self.use_cache_quantization = self.config.use_cache_quantization if hasattr(self.config, 'use_cache_quantization') else False
+        self.use_cache_quantization = (
+            self.config.use_cache_quantization
+            if hasattr(self.config, "use_cache_quantization")
+            else False
+        )
 
         self.gradient_checkpointing = False
         self.use_dynamic_ntk = config.use_dynamic_ntk
@@ -521,25 +567,14 @@ class QWenModel(QWenPreTrainedModel):
             self.rotary_ndims = None
         else:
             assert config.rotary_pct < 1
-            self.rotary_ndims = int(
-                config.kv_channels * config.rotary_pct
-            )
-        dim = (
-            self.rotary_ndims
-            if self.rotary_ndims is not None
-            else config.kv_channels
-        )
+            self.rotary_ndims = int(config.kv_channels * config.rotary_pct)
+        dim = self.rotary_ndims if self.rotary_ndims is not None else config.kv_channels
         self.rotary_emb = RotaryEmbedding(dim, base=config.rotary_emb_base)
 
         self.is_fp32 = not (config.bf16 or config.fp16)
 
         self.h = nn.ModuleList(
-            [
-                QWenBlock(
-                    config
-                )
-                for i in range(config.num_hidden_layers)
-            ]
+            [QWenBlock(config) for i in range(config.num_hidden_layers)]
         )
         self.ln_f = RMSNorm(
             self.embed_dim,
@@ -659,7 +694,12 @@ class QWenModel(QWenPreTrainedModel):
         else:
             ntk_alpha_list = []
             if attention_mask is not None and kv_seq_len > self.seq_length:
-                true_seq_lens = attention_mask.squeeze(1).squeeze(1).eq(0).sum(dim=-1, dtype=torch.int32)
+                true_seq_lens = (
+                    attention_mask.squeeze(1)
+                    .squeeze(1)
+                    .eq(0)
+                    .sum(dim=-1, dtype=torch.int32)
+                )
                 for i in range(hidden_states.size()[0]):
                     true_seq_len = true_seq_lens[i].item()
                     ntk_alpha = self.get_ntk_alpha(true_seq_len)
@@ -669,7 +709,8 @@ class QWenModel(QWenPreTrainedModel):
                 ntk_alpha_list.append(ntk_alpha)
         self.rotary_emb._ntk_alpha_cached_list = ntk_alpha_list
         rotary_pos_emb_list = [
-            self.rotary_emb(kv_seq_len, ntk_alpha=ntk_alpha) for ntk_alpha in ntk_alpha_list
+            self.rotary_emb(kv_seq_len, ntk_alpha=ntk_alpha)
+            for ntk_alpha in ntk_alpha_list
         ]
 
         hidden_states = self.drop(hidden_states)
@@ -686,7 +727,6 @@ class QWenModel(QWenPreTrainedModel):
         all_self_attentions = () if output_attentions else None
         all_hidden_states = () if output_hidden_states else None
         for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
-
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
 
@@ -727,7 +767,9 @@ class QWenModel(QWenPreTrainedModel):
                 presents = presents + (outputs[1],)
 
             if output_attentions:
-                all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
+                all_self_attentions = all_self_attentions + (
+                    outputs[2 if use_cache else 1],
+                )
 
         hidden_states = self.ln_f(hidden_states)
         hidden_states = hidden_states.view(output_shape)
@@ -756,7 +798,7 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         super().__init__(config)
         assert (
             config.bf16 + config.fp16 + config.fp32 <= 1
-        ), "Only one of \"bf16\", \"fp16\", \"fp32\" can be true"
+        ), 'Only one of "bf16", "fp16", "fp32" can be true'
 
         autoset_precision = config.bf16 + config.fp16 + config.fp32 == 0
 
@@ -764,27 +806,35 @@ class QWenLMHeadModel(QWenPreTrainedModel):
             if SUPPORT_BF16:
                 logger.warn(
                     "The model is automatically converting to bf16 for faster inference. "
-                    "If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to \"AutoModelForCausalLM.from_pretrained\"."
+                    'If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to "AutoModelForCausalLM.from_pretrained".'
                 )
                 config.bf16 = True
             elif SUPPORT_FP16:
                 logger.warn(
                     "The model is automatically converting to fp16 for faster inference. "
-                    "If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to \"AutoModelForCausalLM.from_pretrained\"."
+                    'If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to "AutoModelForCausalLM.from_pretrained".'
                 )
                 config.fp16 = True
             else:
                 config.fp32 = True
 
         if config.bf16 and SUPPORT_CUDA and not SUPPORT_BF16:
-            logger.warn("Your device does NOT seem to support bf16, you can switch to fp16 or fp32 by by passing fp16/fp32=True in \"AutoModelForCausalLM.from_pretrained\".")
+            logger.warn(
+                'Your device does NOT seem to support bf16, you can switch to fp16 or fp32 by by passing fp16/fp32=True in "AutoModelForCausalLM.from_pretrained".'
+            )
         if config.fp16 and SUPPORT_CUDA and not SUPPORT_FP16:
-            logger.warn("Your device does NOT support faster inference with fp16, please switch to fp32 which is likely to be faster")
+            logger.warn(
+                "Your device does NOT support faster inference with fp16, please switch to fp32 which is likely to be faster"
+            )
         if config.fp32:
             if SUPPORT_BF16:
-                logger.warn("Your device support faster inference by passing bf16=True in \"AutoModelForCausalLM.from_pretrained\".")
+                logger.warn(
+                    'Your device support faster inference by passing bf16=True in "AutoModelForCausalLM.from_pretrained".'
+                )
             elif SUPPORT_FP16:
-                logger.warn("Your device support faster inference by passing fp16=True in \"AutoModelForCausalLM.from_pretrained\".")
+                logger.warn(
+                    'Your device support faster inference by passing fp16=True in "AutoModelForCausalLM.from_pretrained".'
+                )
 
         self.transformer = QWenModel(config)
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
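The warnings reindented above belong to the automatic precision selection in __init__: when none of bf16/fp16/fp32 is set, the model picks bf16 or fp16 based on SUPPORT_BF16/SUPPORT_FP16 and otherwise falls back to fp32. Passing one flag explicitly skips the auto-detection; a hedged usage sketch (the checkpoint name is an assumed example):

from transformers import AutoModelForCausalLM

# Force bf16 rather than relying on auto-detection; "Qwen/Qwen-7B-Chat" is an
# assumed example checkpoint, and trust_remote_code is needed for this custom
# modeling file.
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen-7B-Chat",
    trust_remote_code=True,
    bf16=True,
)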
@@ -845,7 +895,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
-
         return_dict = (
             return_dict if return_dict is not None else self.config.use_return_dict
         )
@@ -887,7 +936,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         # )
         # loss.backward()
 
-
         if not return_dict:
             output = (lm_logits,) + transformer_outputs[1:]
             return ((loss,) + output) if loss is not None else output
@@ -904,7 +952,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
     def _reorder_cache(
         past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
     ) -> Tuple[Tuple[torch.Tensor]]:
-
         return tuple(
             tuple(
                 past_state.index_select(0, beam_idx.to(past_state.device))
@@ -924,10 +971,14 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         generation_config: Optional[GenerationConfig] = None,
         **kwargs,
     ) -> Tuple[str, HistoryType]:
-        generation_config = generation_config if generation_config is not None else self.generation_config
+        generation_config = (
+            generation_config
+            if generation_config is not None
+            else self.generation_config
+        )
 
         assert stream is _SENTINEL, _ERROR_STREAM_IN_CHAT
-        assert generation_config.chat_format == 'chatml', _ERROR_BAD_CHAT_FORMAT
+        assert generation_config.chat_format == "chatml", _ERROR_BAD_CHAT_FORMAT
         if history is None:
             history = []
         else:
@@ -937,7 +988,7 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         if stop_words_ids is None:
             stop_words_ids = []
 
-        max_window_size = kwargs.get('max_window_size', None)
+        max_window_size = kwargs.get("max_window_size", None)
         if max_window_size is None:
             max_window_size = generation_config.max_window_size
         raw_text, context_tokens = make_context(
@@ -949,17 +1000,17 @@ class QWenLMHeadModel(QWenPreTrainedModel):
             chat_format=generation_config.chat_format,
         )
 
-        stop_words_ids.extend(get_stop_words_ids(
-            generation_config.chat_format, tokenizer
-        ))
+        stop_words_ids.extend(
+            get_stop_words_ids(generation_config.chat_format, tokenizer)
+        )
         input_ids = torch.tensor([context_tokens]).to(self.device)
         outputs = self.generate(
-                    input_ids,
-                    stop_words_ids=stop_words_ids,
-                    return_dict_in_generate=False,
-                    generation_config=generation_config,
-                    **kwargs,
-                )
+            input_ids,
+            stop_words_ids=stop_words_ids,
+            return_dict_in_generate=False,
+            generation_config=generation_config,
+            **kwargs,
+        )
 
         response = decode_tokens(
             outputs[0],
@@ -968,7 +1019,7 @@ class QWenLMHeadModel(QWenPreTrainedModel):
             context_length=len(context_tokens),
             chat_format=generation_config.chat_format,
             verbose=False,
-            errors='replace'
+            errors="replace",
         )
 
         # as history is a copy of the user inputs,
@@ -980,24 +1031,28 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         return response, history
 
     def chat_stream(
-            self,
-            tokenizer: PreTrainedTokenizer,
-            query: str,
-            history: Optional[HistoryType],
-            system: str = "You are a helpful assistant.",
-            stop_words_ids: Optional[List[List[int]]] = None,
-            logits_processor: Optional[LogitsProcessorList] = None,
-            generation_config: Optional[GenerationConfig] = None,
-            **kwargs,
+        self,
+        tokenizer: PreTrainedTokenizer,
+        query: str,
+        history: Optional[HistoryType],
+        system: str = "You are a helpful assistant.",
+        stop_words_ids: Optional[List[List[int]]] = None,
+        logits_processor: Optional[LogitsProcessorList] = None,
+        generation_config: Optional[GenerationConfig] = None,
+        **kwargs,
     ) -> Generator[str, Any, None]:
-        generation_config = generation_config if generation_config is not None else self.generation_config
-        assert generation_config.chat_format == 'chatml', _ERROR_BAD_CHAT_FORMAT
+        generation_config = (
+            generation_config
+            if generation_config is not None
+            else self.generation_config
+        )
+        assert generation_config.chat_format == "chatml", _ERROR_BAD_CHAT_FORMAT
         if history is None:
             history = []
         if stop_words_ids is None:
             stop_words_ids = []
 
-        max_window_size = kwargs.get('max_window_size', None)
+        max_window_size = kwargs.get("max_window_size", None)
         if max_window_size is None:
             max_window_size = generation_config.max_window_size
         raw_text, context_tokens = make_context(
@@ -1009,9 +1064,9 @@ class QWenLMHeadModel(QWenPreTrainedModel):
             chat_format=generation_config.chat_format,
         )
 
-        stop_words_ids.extend(get_stop_words_ids(
-            generation_config.chat_format, tokenizer
-        ))
+        stop_words_ids.extend(
+            get_stop_words_ids(generation_config.chat_format, tokenizer)
+        )
         if stop_words_ids is not None:
             stop_words_logits_processor = StopWordsLogitsProcessor(
                 stop_words_ids=stop_words_ids,
@@ -1023,22 +1078,31 @@ class QWenLMHeadModel(QWenPreTrainedModel):
                 logits_processor.append(stop_words_logits_processor)
         input_ids = torch.tensor([context_tokens]).to(self.device)
 
-        from transformers_stream_generator.main import NewGenerationMixin, StreamGenerationConfig
+        from transformers_stream_generator.main import (
+            NewGenerationMixin,
+            StreamGenerationConfig,
+        )
+
         self.__class__.generate_stream = NewGenerationMixin.generate
         self.__class__.sample_stream = NewGenerationMixin.sample_stream
-        stream_config = StreamGenerationConfig(**generation_config.to_dict(), do_stream=True)
+        stream_config = StreamGenerationConfig(
+            **generation_config.to_dict(), do_stream=True
+        )
 
         def stream_generator():
             outputs = []
             for token in self.generate_stream(
-                    input_ids,
-                    return_dict_in_generate=False,
-                    generation_config=stream_config,
-                    logits_processor=logits_processor,
-                    seed=-1,
-                    **kwargs):
+                input_ids,
+                return_dict_in_generate=False,
+                generation_config=stream_config,
+                logits_processor=logits_processor,
+                seed=-1,
+                **kwargs,
+            ):
                 outputs.append(token.item())
-                yield tokenizer.decode(outputs, skip_special_tokens=True, errors='ignore')
+                yield tokenizer.decode(
+                    outputs, skip_special_tokens=True, errors="ignore"
+                )
 
         return stream_generator()
 
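chat_stream, reindented above, wraps transformers_stream_generator and hands back a generator; each yield is the full response decoded so far, not an incremental delta. A hedged usage sketch, assuming model and tokenizer have already been loaded as in the earlier example:

# Prints progressively longer prefixes of the answer as tokens arrive.
for partial_response in model.chat_stream(tokenizer, "Briefly introduce yourself.", history=None):
    print(partial_response)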
@@ -1056,7 +1120,11 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         streamer: Optional["BaseStreamer"] = None,
         **kwargs,
     ) -> Union[GenerateOutput, torch.LongTensor]:
-        generation_config = generation_config if generation_config is not None else self.generation_config
+        generation_config = (
+            generation_config
+            if generation_config is not None
+            else self.generation_config
+        )
 
         # Process stop_words_ids.
         stop_words_ids = kwargs.pop("stop_words_ids", None)
@@ -1093,7 +1161,9 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         generation_config: Optional[GenerationConfig] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
         stopping_criteria: Optional[StoppingCriteriaList] = None,
-        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
+        prefix_allowed_tokens_fn: Optional[
+            Callable[[int, torch.Tensor], List[int]]
+        ] = None,
         synced_gpus: Optional[bool] = None,
         assistant_model: Optional["PreTrainedModel"] = None,
         streamer: Optional["BaseStreamer"] = None,
@@ -1101,9 +1171,8 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         negative_prompt_attention_mask: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> Union[GenerateOutput, torch.LongTensor]:
-
         if synced_gpus is None:
-                synced_gpus = False
+            synced_gpus = False
 
         # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
         self._validate_model_class()
@@ -1114,8 +1183,10 @@ class QWenLMHeadModel(QWenPreTrainedModel):
             # two conditions must be met
             # 1) the generation config must have been created from the model config (`_from_model_config` field);
             # 2) the generation config must have seen no modification since its creation (the hash is the same).
-            if self.generation_config._from_model_config and self.generation_config._original_object_hash == hash(
-                self.generation_config
+            if (
+                self.generation_config._from_model_config
+                and self.generation_config._original_object_hash
+                == hash(self.generation_config)
             ):
                 new_generation_config = GenerationConfig.from_model_config(self.config)
                 if new_generation_config != self.generation_config:
@@ -1129,15 +1200,26 @@ class QWenLMHeadModel(QWenPreTrainedModel):
             generation_config = self.generation_config
 
         generation_config = copy.deepcopy(generation_config)
-        model_kwargs = generation_config.update(**kwargs)  # All unused kwargs must be model kwargs
+        model_kwargs = generation_config.update(
+            **kwargs
+        )  # All unused kwargs must be model kwargs
         generation_config.validate()
         self._validate_model_kwargs(model_kwargs.copy())
 
         # 2. Set generation parameters if not already defined
-        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
-        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
+        logits_processor = (
+            logits_processor if logits_processor is not None else LogitsProcessorList()
+        )
+        stopping_criteria = (
+            stopping_criteria
+            if stopping_criteria is not None
+            else StoppingCriteriaList()
+        )
 
-        if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
+        if (
+            generation_config.pad_token_id is None
+            and generation_config.eos_token_id is not None
+        ):
             if model_kwargs.get("attention_mask", None) is None:
                 logger.warning(
                     "The attention mask and the pad token id were not set. As a consequence, you may observe "
@@ -1146,7 +1228,9 @@ class QWenLMHeadModel(QWenPreTrainedModel):
             eos_token_id = generation_config.eos_token_id
             if isinstance(eos_token_id, list):
                 eos_token_id = eos_token_id[0]
-            logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
+            logger.warning(
+                f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation."
+            )
             generation_config.pad_token_id = eos_token_id
 
         # 3. Define model inputs
@@ -1169,12 +1253,22 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         else:
             model_kwargs["use_cache"] = generation_config.use_cache
 
-        accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
+        accepts_attention_mask = "attention_mask" in set(
+            inspect.signature(self.forward).parameters.keys()
+        )
         requires_attention_mask = "encoder_outputs" not in model_kwargs
 
-        if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask:
-            model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
-                inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id
+        if (
+            model_kwargs.get("attention_mask", None) is None
+            and requires_attention_mask
+            and accepts_attention_mask
+        ):
+            model_kwargs[
+                "attention_mask"
+            ] = self._prepare_attention_mask_for_generation(
+                inputs_tensor,
+                generation_config.pad_token_id,
+                generation_config.eos_token_id,
             )
 
         # decoder-only models should use left-padding for generation
@@ -1184,7 +1278,8 @@ class QWenLMHeadModel(QWenPreTrainedModel):
             if (
                 generation_config.pad_token_id is not None
                 and len(inputs_tensor.shape) == 2
-                and torch.sum(inputs_tensor[:, -1] == generation_config.pad_token_id) > 0
+                and torch.sum(inputs_tensor[:, -1] == generation_config.pad_token_id)
+                > 0
             ):
                 logger.warning(
                     "A decoder-only architecture is being used, but right-padding was detected! For correct "
@@ -1209,14 +1304,21 @@ class QWenLMHeadModel(QWenPreTrainedModel):
                 device=inputs_tensor.device,
             )
         else:
-            input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")
+            input_ids = (
+                inputs_tensor
+                if model_input_name == "input_ids"
+                else model_kwargs.pop("input_ids")
+            )
 
         if streamer is not None:
             streamer.put(input_ids.cpu())
 
         # 6. Prepare `max_length` depending on other stopping criteria.
         input_ids_length = input_ids.shape[-1]
-        has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
+        has_default_max_length = (
+            kwargs.get("max_length") is None
+            and generation_config.max_length is not None
+        )
         if generation_config.max_new_tokens is not None:
             if not has_default_max_length and generation_config.max_length is not None:
                 logger.warning(
@@ -1225,8 +1327,12 @@ class QWenLMHeadModel(QWenPreTrainedModel):
                     "Please refer to the documentation for more information. "
                     "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
                 )
-            generation_config.max_length = generation_config.max_new_tokens + input_ids_length
-        self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
+            generation_config.max_length = (
+                generation_config.max_new_tokens + input_ids_length
+            )
+        self._validate_generated_length(
+            generation_config, input_ids_length, has_default_max_length
+        )
 
         # 7. determine generation mode
         generation_mode = self._get_generation_mode(generation_config, assistant_model)
@@ -1291,7 +1397,6 @@ class QWenLMHeadModel(QWenPreTrainedModel):
             **model_kwargs,
         )
 
-
     def sample_base(
         self,
         input_ids: torch.LongTensor,
@@ -1309,10 +1414,15 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         streamer: Optional["BaseStreamer"] = None,
         **model_kwargs,
     ):
-
         # init values
-        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
-        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
+        logits_processor = (
+            logits_processor if logits_processor is not None else LogitsProcessorList()
+        )
+        stopping_criteria = (
+            stopping_criteria
+            if stopping_criteria is not None
+            else StoppingCriteriaList()
+        )
         # if max_length is not None:
         #     warnings.warn(
         #         "`max_length` is deprecated in this function, use"
@@ -1320,18 +1430,40 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         #         UserWarning,
         #     )
         #     stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
-        logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
-        pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
-        eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
+        logits_warper = (
+            logits_warper if logits_warper is not None else LogitsProcessorList()
+        )
+        pad_token_id = (
+            pad_token_id
+            if pad_token_id is not None
+            else self.generation_config.pad_token_id
+        )
+        eos_token_id = (
+            eos_token_id
+            if eos_token_id is not None
+            else self.generation_config.eos_token_id
+        )
         if isinstance(eos_token_id, int):
             eos_token_id = [eos_token_id]
-        eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
-        output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
+        eos_token_id_tensor = (
+            torch.tensor(eos_token_id).to(input_ids.device)
+            if eos_token_id is not None
+            else None
+        )
+        output_scores = (
+            output_scores
+            if output_scores is not None
+            else self.generation_config.output_scores
+        )
         output_attentions = (
-            output_attentions if output_attentions is not None else self.generation_config.output_attentions
+            output_attentions
+            if output_attentions is not None
+            else self.generation_config.output_attentions
         )
         output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
+            output_hidden_states
+            if output_hidden_states is not None
+            else self.generation_config.output_hidden_states
        )
         return_dict_in_generate = (
             return_dict_in_generate
|  | @ -1341,19 +1473,33 @@ class QWenLMHeadModel(QWenPreTrainedModel): | ||||||
| 
 | 
 | ||||||
|         # init attention / hidden states / scores tuples |         # init attention / hidden states / scores tuples | ||||||
|         scores = () if (return_dict_in_generate and output_scores) else None |         scores = () if (return_dict_in_generate and output_scores) else None | ||||||
|         decoder_attentions = () if (return_dict_in_generate and output_attentions) else None |         decoder_attentions = ( | ||||||
|         cross_attentions = () if (return_dict_in_generate and output_attentions) else None |             () if (return_dict_in_generate and output_attentions) else None | ||||||
|         decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None |         ) | ||||||
|  |         cross_attentions = ( | ||||||
|  |             () if (return_dict_in_generate and output_attentions) else None | ||||||
|  |         ) | ||||||
|  |         decoder_hidden_states = ( | ||||||
|  |             () if (return_dict_in_generate and output_hidden_states) else None | ||||||
|  |         ) | ||||||
| 
 | 
 | ||||||
|         # if model is an encoder-decoder, retrieve encoder attention weights and hidden states |         # if model is an encoder-decoder, retrieve encoder attention weights and hidden states | ||||||
|         if return_dict_in_generate and self.config.is_encoder_decoder: |         if return_dict_in_generate and self.config.is_encoder_decoder: | ||||||
|             encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None |             encoder_attentions = ( | ||||||
|  |                 model_kwargs["encoder_outputs"].get("attentions") | ||||||
|  |                 if output_attentions | ||||||
|  |                 else None | ||||||
|  |             ) | ||||||
|             encoder_hidden_states = ( |             encoder_hidden_states = ( | ||||||
|                 model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None |                 model_kwargs["encoder_outputs"].get("hidden_states") | ||||||
|  |                 if output_hidden_states | ||||||
|  |                 else None | ||||||
|             ) |             ) | ||||||
| 
 | 
 | ||||||
|         # keep track of which sequences are already finished |         # keep track of which sequences are already finished | ||||||
|         unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device) |         unfinished_sequences = torch.ones( | ||||||
|  |             input_ids.shape[0], dtype=torch.long, device=input_ids.device | ||||||
|  |         ) | ||||||
| 
 | 
 | ||||||
|         this_peer_finished = False  # used by synced_gpus only |         this_peer_finished = False  # used by synced_gpus only | ||||||
|         # auto-regressive generation |         # auto-regressive generation | ||||||
|  | @ -1394,7 +1540,9 @@ class QWenLMHeadModel(QWenPreTrainedModel): | ||||||
|                     scores += (next_token_scores,) |                     scores += (next_token_scores,) | ||||||
|                 if output_attentions: |                 if output_attentions: | ||||||
|                     decoder_attentions += ( |                     decoder_attentions += ( | ||||||
|                         (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) |                         (outputs.decoder_attentions,) | ||||||
|  |                         if self.config.is_encoder_decoder | ||||||
|  |                         else (outputs.attentions,) | ||||||
|                     ) |                     ) | ||||||
|                     if self.config.is_encoder_decoder: |                     if self.config.is_encoder_decoder: | ||||||
|                         cross_attentions += (outputs.cross_attentions,) |                         cross_attentions += (outputs.cross_attentions,) | ||||||
|  | @ -1413,8 +1561,12 @@ class QWenLMHeadModel(QWenPreTrainedModel): | ||||||
|             # finished sentences should have their next token be a padding token |             # finished sentences should have their next token be a padding token | ||||||
|             if eos_token_id is not None: |             if eos_token_id is not None: | ||||||
|                 if pad_token_id is None: |                 if pad_token_id is None: | ||||||
|                     raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") |                     raise ValueError( | ||||||
|                 next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) |                         "If `eos_token_id` is defined, make sure that `pad_token_id` is defined." | ||||||
|  |                     ) | ||||||
|  |                 next_tokens = next_tokens * unfinished_sequences + pad_token_id * ( | ||||||
|  |                     1 - unfinished_sequences | ||||||
|  |                 ) | ||||||
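The masking arithmetic above keeps finished rows emitting `pad_token_id` while active rows keep their sampled token; `unfinished_sequences` holds 1 for active and 0 for finished sequences. A small self-contained illustration with dummy token ids (all values here are assumptions for the example):

    import torch

    unfinished_sequences = torch.tensor([1, 0, 1])   # dummy: row 1 already finished
    next_tokens = torch.tensor([42, 17, 99])         # dummy sampled tokens
    pad_token_id = 0

    # finished rows are overwritten with the pad token, active rows pass through
    next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
    print(next_tokens)                               # tensor([42,  0, 99])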
| 
 | 
 | ||||||
|             # update generated ids, model inputs, and length for next step |             # update generated ids, model inputs, and length for next step | ||||||
|             input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) |             input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) | ||||||
|  | @ -1427,7 +1579,9 @@ class QWenLMHeadModel(QWenPreTrainedModel): | ||||||
|             # if eos_token was found in one sentence, set sentence to finished |             # if eos_token was found in one sentence, set sentence to finished | ||||||
|             if eos_token_id_tensor is not None: |             if eos_token_id_tensor is not None: | ||||||
|                 unfinished_sequences = unfinished_sequences.mul( |                 unfinished_sequences = unfinished_sequences.mul( | ||||||
|                     next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0) |                     next_tokens.tile(eos_token_id_tensor.shape[0], 1) | ||||||
|  |                     .ne(eos_token_id_tensor.unsqueeze(1)) | ||||||
|  |                     .prod(dim=0) | ||||||
|                 ) |                 ) | ||||||
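The `tile(...).ne(...).prod(dim=0)` chain compares each freshly sampled token against every configured EOS id and zeroes the corresponding entry of `unfinished_sequences` when any id matches. A compact sketch with dummy ids (the concrete numbers are assumptions):

    import torch

    next_tokens = torch.tensor([5, 7, 9])        # dummy sampled token per sequence
    eos_token_id_tensor = torch.tensor([7, 9])   # dummy: two EOS ids configured
    unfinished_sequences = torch.tensor([1, 1, 1])

    # one row per EOS id; an entry is False where that sequence's token equals the id
    not_eos = next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1))
    unfinished_sequences = unfinished_sequences.mul(not_eos.prod(dim=0))
    print(unfinished_sequences)                  # tensor([1, 0, 0]): sequences 1 and 2 just finished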
| 
 | 
 | ||||||
|                 # stop when each sentence is finished |                 # stop when each sentence is finished | ||||||
|  | @ -1446,7 +1600,6 @@ class QWenLMHeadModel(QWenPreTrainedModel): | ||||||
| 
 | 
 | ||||||
|         return input_ids |         return input_ids | ||||||
| 
 | 
 | ||||||
| 
 |  | ||||||
|     # def backward( |     # def backward( | ||||||
|     #     self, |     #     self, | ||||||
|     #     tokenizer, |     #     tokenizer, | ||||||
|  | @ -1539,7 +1692,7 @@ def _rotate_half(x): | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def apply_rotary_pos_emb(t, freqs): | def apply_rotary_pos_emb(t, freqs): | ||||||
|     """ Apply rotary embedding to the first rotary_dim of the iput |     """Apply rotary embedding to the first rotary_dim of the iput | ||||||
| 
 | 
 | ||||||
|     Arguments: |     Arguments: | ||||||
|       t (tensor(batch_size, seq_len, n_head, head_dim)): |       t (tensor(batch_size, seq_len, n_head, head_dim)): | ||||||
|  |  | ||||||
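The docstring describes applying rotary embedding only to the first `rotary_dim` channels of `t` (shape batch_size, seq_len, n_head, head_dim). Below is a hedged sketch of that standard pattern, assuming the usual `t * cos + rotate_half(t) * sin` form; the `_sketch` names and the cos/sin shapes are assumptions, not this file's exact implementation:

    import torch

    def _rotate_half_sketch(x):
        # split the last dim in half, swap the halves and negate the second one
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    def apply_rotary_pos_emb_sketch(t, cos, sin, rotary_dim):
        # t: (batch, seq_len, n_head, head_dim); cos/sin assumed broadcastable
        # over the first rotary_dim channels
        t_rot, t_pass = t[..., :rotary_dim], t[..., rotary_dim:]
        t_rot = t_rot * cos + _rotate_half_sketch(t_rot) * sin
        return torch.cat((t_rot, t_pass), dim=-1)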