From 90cb0fe2366a5c5f9865b725bef769bb320a2de5 Mon Sep 17 00:00:00 2001 From: Colin Date: Sun, 7 Jan 2024 16:53:53 +0800 Subject: [PATCH] Refine model of qwen. --- qwen/modeling_qwen.py | 487 +++++------------------------------------- 1 file changed, 51 insertions(+), 436 deletions(-) diff --git a/qwen/modeling_qwen.py b/qwen/modeling_qwen.py index d15aefc..a3138a6 100644 --- a/qwen/modeling_qwen.py +++ b/qwen/modeling_qwen.py @@ -36,12 +36,6 @@ except ImportError: from torch import nn SUPPORT_CUDA = torch.cuda.is_available() -SUPPORT_BF16 = SUPPORT_CUDA and torch.cuda.is_bf16_supported() -SUPPORT_FP16 = SUPPORT_CUDA and torch.cuda.get_device_capability(0)[0] >= 7 -SUPPORT_TORCH2 = ( - hasattr(torch, "__version__") and int(torch.__version__.split(".")[0]) >= 2 -) - from configuration_qwen import QWenConfig from qwen_generation_utils import ( @@ -69,37 +63,6 @@ Pass argument `stream` to model.chat() is buggy, deprecated, and marked for remo 向model.chat()传入参数stream的用法可能存在Bug,该用法已被废弃,将在未来被移除。请使用model.chat_stream(...)代替model.chat(..., stream=True)。 """ -apply_rotary_emb_func = None -rms_norm = None - - -def quantize_cache_v(fdata, bits, qmax, qmin): - # b, s, head, h-dim->b, head, s, h-dim - qtype = torch.uint8 - device = fdata.device - shape = fdata.shape - - fdata_cal = torch.flatten(fdata, 2) - fmax = torch.amax(fdata_cal, dim=-1, keepdim=True) - fmin = torch.amin(fdata_cal, dim=-1, keepdim=True) - # Compute params - if qmax.device != fmax.device: - qmax = qmax.to(device) - qmin = qmin.to(device) - scale = (fmax - fmin) / (qmax - qmin) - zero = qmin - fmin / scale - scale = scale.unsqueeze(-1).repeat(1, 1, shape[2], 1).contiguous() - zero = zero.unsqueeze(-1).repeat(1, 1, shape[2], 1).contiguous() - # Quantize - res_data = fdata / scale + zero - qdata = torch.clamp(res_data, qmin, qmax).to(qtype) - return qdata.contiguous(), scale, zero - - -def dequantize_cache_torch(qdata, scale, zero): - data = scale * (qdata - zero) - return data - class QWenAttention(nn.Module): def __init__(self, config): @@ -128,9 +91,7 @@ class QWenAttention(nn.Module): config.hidden_size, self.projection_size, bias=not config.no_bias ) - self.is_fp32 = not (config.bf16 or config.fp16) - - self.bf16 = config.bf16 + self.is_fp32 = True self.use_dynamic_ntk = config.use_dynamic_ntk self.use_logn_attn = config.use_logn_attn @@ -146,128 +107,13 @@ class QWenAttention(nn.Module): self.softmax_in_fp32 = ( config.softmax_in_fp32 if hasattr(config, "softmax_in_fp32") else False ) - self.use_cache_quantization = ( - config.use_cache_quantization - if hasattr(config, "use_cache_quantization") - else False - ) self.use_cache_kernel = ( config.use_cache_kernel if hasattr(config, "use_cache_kernel") else False ) cache_dtype = torch.float - if self.bf16: - cache_dtype = torch.bfloat16 - elif config.fp16: - cache_dtype = torch.float16 self.cache_qmax = torch.tensor(torch.iinfo(torch.uint8).max, dtype=cache_dtype) self.cache_qmin = torch.tensor(torch.iinfo(torch.uint8).min, dtype=cache_dtype) - if config.use_cache_quantization and config.use_cache_kernel: - # pre check if the support files existing - module_root = pathlib.Path(__file__).parent - src_files = ( - "cache_autogptq_cuda_256.cpp", - "cache_autogptq_cuda_kernel_256.cu", - ) - if any(not (module_root / src).is_file() for src in src_files): - warnings.warn("KV cache kernel source files (.cpp and .cu) not found.") - self.cache_kernels = None - else: - try: - from .cpp_kernels import cache_autogptq_cuda_256 - - self.cache_kernels = cache_autogptq_cuda_256 - except ImportError: 
- warnings.warn("Failed to import KV cache kernels.") - self.cache_kernels = None - - def _attn( - self, query, key, value, causal_mask=None, attention_mask=None, head_mask=None - ): - device = query.device - if self.use_cache_quantization: - qk, qk_scale, qk_zero = key - if self.use_cache_kernel and self.cache_kernels is not None: - shape = query.shape[:-1] + (qk.shape[-2],) - attn_weights = torch.zeros(shape, dtype=torch.float16, device=device) - self.cache_kernels.vecquant8matmul_batched_faster_old( - query.contiguous() - if query.dtype == torch.float16 - else query.to(torch.float16).contiguous(), - qk.transpose(-1, -2).contiguous(), - attn_weights, - qk_scale.contiguous() - if qk_scale.dtype == torch.float16 - else qk_scale.to(torch.float16).contiguous(), - qk_zero.contiguous() - if qk_zero.dtype == torch.float16 - else qk_zero.to(torch.float16).contiguous(), - ) - # attn_weights = attn_weights.to(query.dtype).contiguous() - else: - key = dequantize_cache_torch(qk, qk_scale, qk_zero) - attn_weights = torch.matmul(query, key.transpose(-1, -2)) - else: - attn_weights = torch.matmul(query, key.transpose(-1, -2)) - - if self.scale_attn_weights: - if self.use_cache_quantization: - size_temp = value[0].size(-1) - else: - size_temp = value.size(-1) - attn_weights = attn_weights / (size_temp**0.5) - - mask_value = torch.finfo(attn_weights.dtype).min - if causal_mask is not None: - attn_weights = torch.where( - causal_mask, attn_weights.to(attn_weights.dtype), mask_value - ) - - if attention_mask is not None: - attn_weights = attn_weights + attention_mask - - if self.softmax_in_fp32: - attn_weights = nn.functional.softmax(attn_weights.float(), dim=-1) - else: - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - attn_weights = attn_weights.type(query.dtype) - attn_weights = self.attn_dropout(attn_weights) - - if head_mask is not None: - attn_weights = attn_weights * head_mask - - if self.use_cache_quantization: - qv, qv_scale, qv_zero = value - if self.use_cache_kernel and self.cache_kernels is not None: - shape = attn_weights.shape[:-1] + (query.shape[-1],) - attn_output = torch.zeros(shape, dtype=torch.float16, device=device) - self.cache_kernels.vecquant8matmul_batched_column_compression_faster_old( - attn_weights.contiguous() - if attn_weights.dtype == torch.float16 - else attn_weights.to(torch.float16).contiguous(), - qv.contiguous(), # dtype: int32 - attn_output, - qv_scale.contiguous() - if qv_scale.dtype == torch.float16 - else qv_scale.to(torch.float16).contiguous(), - qv_zero.contiguous() - if qv_zero.dtype == torch.float16 - else qv_zero.to(torch.float16).contiguous(), - ) - if attn_output.dtype != query.dtype: - attn_output = attn_output.to(query.dtype) - attn_weights = attn_weights.to(query.dtype) - else: - value = dequantize_cache_torch(qv, qv_scale, qv_zero) - attn_output = torch.matmul(attn_weights, value) - else: - attn_output = torch.matmul(attn_weights, value) - - attn_output = attn_output.transpose(1, 2) - - return attn_output, attn_weights - def _split_heads(self, tensor, num_heads, attn_head_size): new_shape = tensor.size()[:-1] + (num_heads, attn_head_size) tensor = tensor.view(new_shape) @@ -323,59 +169,25 @@ class QWenAttention(nn.Module): query = torch.cat(query_list, dim=0) key = torch.cat(key_list, dim=0) - if self.use_cache_quantization: - key = quantize_cache_v( - key.permute(0, 2, 1, 3), - bits=8, - qmin=self.cache_qmin, - qmax=self.cache_qmax, - ) - value = quantize_cache_v( - value.permute(0, 2, 1, 3), - bits=8, - qmin=self.cache_qmin, - 
qmax=self.cache_qmax, - ) - if layer_past is not None: past_key, past_value = layer_past[0], layer_past[1] - if self.use_cache_quantization: - # use_cache_quantization: - # present=((q_key,key_scale,key_zero_point), - # (q_value,value_scale,value_zero_point)) - key = ( - torch.cat((past_key[0], key[0]), dim=2), - torch.cat((past_key[1], key[1]), dim=2), - torch.cat((past_key[2], key[2]), dim=2), - ) - value = ( - torch.cat((past_value[0], value[0]), dim=2), - torch.cat((past_value[1], value[1]), dim=2), - torch.cat((past_value[2], value[2]), dim=2), - ) - else: - # not use_cache_quantization: - # present=(key,value) - key = torch.cat((past_key, key), dim=1) - value = torch.cat((past_value, value), dim=1) + + key = torch.cat((past_key, key), dim=1) + value = torch.cat((past_value, value), dim=1) if use_cache: present = (key, value) else: present = None - key_size = key[0].size(2) if self.use_cache_quantization else key.size(1) + key_size = key.size(1) if key_size > self.seq_length and self.use_logn_attn and not self.training: - if self.use_cache_quantization: - seq_start = key[0].size(2) - query.size(1) - seq_end = key[0].size(2) - else: - seq_start = key.size(1) - query.size(1) - seq_end = key.size(1) + seq_start = key.size(1) - query.size(1) + seq_end = key.size(1) logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :].type_as(query) query = query * logn_tensor.expand_as(query) - key_size = key[0].size(2) if self.use_cache_quantization else key.size(1) + key_size = key.size(1) if query.size(1) == key_size: causal_mask = torch.tril( torch.ones((key_size, key_size), dtype=torch.bool, device=query.device) @@ -383,39 +195,31 @@ class QWenAttention(nn.Module): else: causal_mask = None query = query.permute(0, 2, 1, 3) - if not self.use_cache_quantization: - key = key.permute(0, 2, 1, 3) - value = value.permute(0, 2, 1, 3) + key = key.permute(0, 2, 1, 3) + value = value.permute(0, 2, 1, 3) - if not self.use_cache_quantization and SUPPORT_TORCH2: - if attention_mask is not None: - attention_mask = attention_mask.expand(-1, -1, causal_mask.size(2), -1) - if causal_mask is not None: - attention_mask = attention_mask.masked_fill( - ~causal_mask, torch.finfo(query.dtype).min - ) - else: - attention_mask = causal_mask - attn_output = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask - ).transpose(1, 2) - attn_weight = None + if attention_mask is not None: + attention_mask = attention_mask.expand(-1, -1, causal_mask.size(2), -1) + if causal_mask is not None: + attention_mask = attention_mask.masked_fill( + ~causal_mask, torch.finfo(query.dtype).min + ) else: - attn_output, attn_weight = self._attn( - query, key, value, causal_mask, attention_mask, head_mask - ) + attention_mask = causal_mask + attn_output = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask + ).transpose(1, 2) + attn_weight = None + context_layer = self._merge_heads(attn_output, self.num_heads, self.head_dim) attn_output = self.c_proj(context_layer) outputs = (attn_output, present) if output_attentions: - if not self.use_cache_quantization and SUPPORT_TORCH2: - raise ValueError( - "Cannot output attentions while using scaled_dot_product_attention" - ) - else: - outputs += (attn_weight,) + raise ValueError( + "Cannot output attentions while using scaled_dot_product_attention" + ) return outputs @@ -444,7 +248,6 @@ class QWenBlock(nn.Module): def __init__(self, config): super().__init__() hidden_size = config.hidden_size - self.bf16 = config.bf16 self.ln_1 = RMSNorm( 
hidden_size, @@ -549,11 +352,6 @@ class QWenModel(QWenPreTrainedModel): self.vocab_size = config.vocab_size self.num_hidden_layers = config.num_hidden_layers self.embed_dim = config.hidden_size - self.use_cache_quantization = ( - self.config.use_cache_quantization - if hasattr(self.config, "use_cache_quantization") - else False - ) self.gradient_checkpointing = False self.use_dynamic_ntk = config.use_dynamic_ntk @@ -571,7 +369,7 @@ class QWenModel(QWenPreTrainedModel): dim = self.rotary_ndims if self.rotary_ndims is not None else config.kv_channels self.rotary_emb = RotaryEmbedding(dim, base=config.rotary_emb_base) - self.is_fp32 = not (config.bf16 or config.fp16) + self.is_fp32 = True self.h = nn.ModuleList( [QWenBlock(config) for i in range(config.num_hidden_layers)] @@ -651,10 +449,7 @@ class QWenModel(QWenPreTrainedModel): past_length = 0 past_key_values = tuple([None] * len(self.h)) else: - if self.use_cache_quantization: - past_length = past_key_values[0][0][0].size(2) - else: - past_length = past_key_values[0][0].size(-2) + past_length = past_key_values[0][0].size(-2) if position_ids is None: position_ids = torch.arange( past_length, @@ -682,10 +477,7 @@ class QWenModel(QWenPreTrainedModel): kv_seq_len = hidden_states.size()[1] if past_key_values[0] is not None: # past key values[0][0] shape: bs * seq_len * head_num * dim - if self.use_cache_quantization: - kv_seq_len += past_key_values[0][0][0].shape[2] - else: - kv_seq_len += past_key_values[0][0].shape[1] + kv_seq_len += past_key_values[0][0].shape[1] if self.training or not self.use_dynamic_ntk: ntk_alpha_list = [1.0] @@ -796,55 +588,9 @@ class QWenLMHeadModel(QWenPreTrainedModel): def __init__(self, config): super().__init__(config) - assert ( - config.bf16 + config.fp16 + config.fp32 <= 1 - ), 'Only one of "bf16", "fp16", "fp32" can be true' - - autoset_precision = config.bf16 + config.fp16 + config.fp32 == 0 - - if autoset_precision: - if SUPPORT_BF16: - logger.warn( - "The model is automatically converting to bf16 for faster inference. " - 'If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to "AutoModelForCausalLM.from_pretrained".' - ) - config.bf16 = True - elif SUPPORT_FP16: - logger.warn( - "The model is automatically converting to fp16 for faster inference. " - 'If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to "AutoModelForCausalLM.from_pretrained".' - ) - config.fp16 = True - else: - config.fp32 = True - - if config.bf16 and SUPPORT_CUDA and not SUPPORT_BF16: - logger.warn( - 'Your device does NOT seem to support bf16, you can switch to fp16 or fp32 by by passing fp16/fp32=True in "AutoModelForCausalLM.from_pretrained".' - ) - if config.fp16 and SUPPORT_CUDA and not SUPPORT_FP16: - logger.warn( - "Your device does NOT support faster inference with fp16, please switch to fp32 which is likely to be faster" - ) - if config.fp32: - if SUPPORT_BF16: - logger.warn( - 'Your device support faster inference by passing bf16=True in "AutoModelForCausalLM.from_pretrained".' - ) - elif SUPPORT_FP16: - logger.warn( - 'Your device support faster inference by passing fp16=True in "AutoModelForCausalLM.from_pretrained".' 
- ) self.transformer = QWenModel(config) self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - - if config.bf16: - self.transformer.bfloat16() - self.lm_head.bfloat16() - if config.fp16: - self.transformer.half() - self.lm_head.half() self.post_init() def get_output_embeddings(self): @@ -928,13 +674,13 @@ class QWenLMHeadModel(QWenPreTrainedModel): shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1) ) - # shift_labels = torch.ones([1,19]).to(lm_logits.device).to(torch.int64) - # shift_logits = lm_logits[..., :-1, :].contiguous() - # loss_fct = CrossEntropyLoss() - # loss = loss_fct( - # shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1) - # ) - # loss.backward() + shift_labels = torch.ones([1,19]).to(lm_logits.device).to(torch.int64) + shift_logits = lm_logits[..., :-1, :].contiguous() + loss_fct = CrossEntropyLoss() + loss = loss_fct( + shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1) + ) + loss.backward() if not return_dict: output = (lm_logits,) + transformer_outputs[1:] @@ -948,18 +694,6 @@ class QWenLMHeadModel(QWenPreTrainedModel): attentions=transformer_outputs.attentions, ) - @staticmethod - def _reorder_cache( - past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor - ) -> Tuple[Tuple[torch.Tensor]]: - return tuple( - tuple( - past_state.index_select(0, beam_idx.to(past_state.device)) - for past_state in layer_past - ) - for layer_past in past_key_values - ) - def chat( self, tokenizer: PreTrainedTokenizer, @@ -1171,9 +905,6 @@ class QWenLMHeadModel(QWenPreTrainedModel): negative_prompt_attention_mask: Optional[torch.Tensor] = None, **kwargs, ) -> Union[GenerateOutput, torch.LongTensor]: - if synced_gpus is None: - synced_gpus = False - # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call self._validate_model_class() @@ -1271,44 +1002,13 @@ class QWenLMHeadModel(QWenPreTrainedModel): generation_config.eos_token_id, ) - # decoder-only models should use left-padding for generation - if not self.config.is_encoder_decoder: - # If `input_ids` was given, check if the last id in any sequence is `pad_token_id` - # Note: If using, `inputs_embeds` this check does not work, because we want to be more hands-off. - if ( - generation_config.pad_token_id is not None - and len(inputs_tensor.shape) == 2 - and torch.sum(inputs_tensor[:, -1] == generation_config.pad_token_id) - > 0 - ): - logger.warning( - "A decoder-only architecture is being used, but right-padding was detected! For correct " - "generation results, please set `padding_side='left'` when initializing the tokenizer." - ) - - if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs: - # if model is encoder decoder encoder_outputs are created - # and added to `model_kwargs` - model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation( - inputs_tensor, model_kwargs, model_input_name - ) - # 5. 
Prepare `input_ids` which will be used for auto-regressive generation - if self.config.is_encoder_decoder: - input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation( - batch_size=batch_size, - model_input_name=model_input_name, - model_kwargs=model_kwargs, - decoder_start_token_id=generation_config.decoder_start_token_id, - bos_token_id=generation_config.bos_token_id, - device=inputs_tensor.device, - ) - else: - input_ids = ( - inputs_tensor - if model_input_name == "input_ids" - else model_kwargs.pop("input_ids") - ) + + input_ids = ( + inputs_tensor + if model_input_name == "input_ids" + else model_kwargs.pop("input_ids") + ) if streamer is not None: streamer.put(input_ids.cpu()) @@ -1378,7 +1078,7 @@ class QWenLMHeadModel(QWenPreTrainedModel): input_ids, model_kwargs = self._expand_inputs_for_generation( input_ids=input_ids, expand_size=generation_config.num_return_sequences, - is_encoder_decoder=self.config.is_encoder_decoder, + is_encoder_decoder=False, **model_kwargs, ) @@ -1483,19 +1183,6 @@ class QWenLMHeadModel(QWenPreTrainedModel): () if (return_dict_in_generate and output_hidden_states) else None ) - # if model is an encoder-decoder, retrieve encoder attention weights and hidden states - if return_dict_in_generate and self.config.is_encoder_decoder: - encoder_attentions = ( - model_kwargs["encoder_outputs"].get("attentions") - if output_attentions - else None - ) - encoder_hidden_states = ( - model_kwargs["encoder_outputs"].get("hidden_states") - if output_hidden_states - else None - ) - # keep track of which sequences are already finished unfinished_sequences = torch.ones( input_ids.shape[0], dtype=torch.long, device=input_ids.device @@ -1504,16 +1191,6 @@ class QWenLMHeadModel(QWenPreTrainedModel): this_peer_finished = False # used by synced_gpus only # auto-regressive generation while True: - # if synced_gpus: - # # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. - # # The following logic allows an early break if all peers finished generating their sequence - # this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device) - # # send 0.0 if we finished, 1.0 otherwise - # dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) - # # did all peers finish? 
the reduced sum will be 0.0 then - # if this_peer_finished_flag.item() == 0.0: - # break - # prepare model inputs model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) @@ -1525,9 +1202,6 @@ class QWenLMHeadModel(QWenPreTrainedModel): output_hidden_states=output_hidden_states, ) - if synced_gpus and this_peer_finished: - continue # don't waste resources running the code we don't need - next_token_logits = outputs.logits[:, -1, :] # pre-process distribution @@ -1539,20 +1213,10 @@ class QWenLMHeadModel(QWenPreTrainedModel): if output_scores: scores += (next_token_scores,) if output_attentions: - decoder_attentions += ( - (outputs.decoder_attentions,) - if self.config.is_encoder_decoder - else (outputs.attentions,) - ) - if self.config.is_encoder_decoder: - cross_attentions += (outputs.cross_attentions,) + decoder_attentions += (outputs.attentions,) if output_hidden_states: - decoder_hidden_states += ( - (outputs.decoder_hidden_states,) - if self.config.is_encoder_decoder - else (outputs.hidden_states,) - ) + decoder_hidden_states += (outputs.hidden_states,) # sample probs = nn.functional.softmax(next_token_scores, dim=-1) @@ -1573,7 +1237,7 @@ class QWenLMHeadModel(QWenPreTrainedModel): if streamer is not None: streamer.put(next_tokens.cpu()) model_kwargs = self._update_model_kwargs_for_generation( - outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder + outputs, model_kwargs, is_encoder_decoder=False ) # if eos_token was found in one sentence, set sentence to finished @@ -1592,7 +1256,7 @@ class QWenLMHeadModel(QWenPreTrainedModel): if stopping_criteria(input_ids, scores): this_peer_finished = True - if this_peer_finished and not synced_gpus: + if this_peer_finished: break if streamer is not None: @@ -1600,44 +1264,6 @@ class QWenLMHeadModel(QWenPreTrainedModel): return input_ids - # def backward( - # self, - # tokenizer, - # query: str, - # ): - # inputs = tokenizer.build_chat_input(query, history=[], role="user") - # inputs = inputs.to(next(self.parameters()).device) - - # generation_config = copy.deepcopy(self.generation_config) - # inputs_tensor = inputs["input_ids"] - # input_ids = inputs_tensor.repeat_interleave( - # generation_config.num_return_sequences, dim=0 - # ) - - # input_ids_in = input_ids - # batch_size, seq_length = input_ids_in.shape - # position_ids_in = ( - # torch.arange(seq_length, dtype=torch.long, device=input_ids.device) - # .unsqueeze(0) - # .repeat(batch_size, 1) - # ) - # model_inputs = {"input_ids": input_ids_in, "position_ids": position_ids_in} - - # probs, next_tokens = self.transformer( - # **model_inputs, - # output_hidden_states=None, - # tokenizer=tokenizer, - # ) - - # next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) - # # probs_target = probs - # # probs_target[0, next_tokens] = probs_target[0, next_tokens] * 1.1 - - # loss = probs[0, next_tokens] - # loss.backward() - - # return loss - class RotaryEmbedding(torch.nn.Module): def __init__(self, dim, base=10000): @@ -1703,17 +1329,9 @@ def apply_rotary_pos_emb(t, freqs): rot_dim = freqs[0].shape[-1] cos, sin = freqs t_float = t.float() - if apply_rotary_emb_func is not None and t.is_cuda: - # apply_rotary_emb in flash_attn requires cos/sin to be of - # shape (seqlen, rotary_dim / 2) and apply rotary embedding - # to the first rotary_dim of the input - cos = cos.squeeze(0).squeeze(1)[:, : rot_dim // 2] - sin = sin.squeeze(0).squeeze(1)[:, : rot_dim // 2] - return apply_rotary_emb_func(t_float, cos, sin).type_as(t) - else: - t_rot, t_pass = 
t_float[..., :rot_dim], t_float[..., rot_dim:] - t_rot = (t_rot * cos) + (_rotate_half(t_rot) * sin) - return torch.cat((t_rot, t_pass), dim=-1).type_as(t) + t_rot, t_pass = t_float[..., :rot_dim], t_float[..., rot_dim:] + t_rot = (t_rot * cos) + (_rotate_half(t_rot) * sin) + return torch.cat((t_rot, t_pass), dim=-1).type_as(t) class RMSNorm(torch.nn.Module): @@ -1726,8 +1344,5 @@ class RMSNorm(torch.nn.Module): return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) def forward(self, x): - if rms_norm is not None and x.is_cuda: - return rms_norm(x, self.weight, self.eps) - else: - output = self._norm(x.float()).type_as(x) - return output * self.weight + output = self._norm(x.float()).type_as(x) + return output * self.weight
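
For reference, a minimal standalone sketch of the two pure-PyTorch paths this patch leaves as the only code paths: attention via F.scaled_dot_product_attention with an explicit boolean causal mask, and the fp32 RMSNorm fallback. This is an illustration under assumed toy shapes, not code from modeling_qwen.py; the helper name rms_norm_ref and all tensor sizes are made up for the example.

import torch
import torch.nn.functional as F

def rms_norm_ref(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Same math as RMSNorm.forward above: normalize in fp32, cast back, scale by weight.
    normed = x.float() * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + eps)
    return normed.type_as(x) * weight

# Illustrative sizes (batch, heads, sequence length, head dim) -- assumptions, not model config.
bsz, n_heads, q_len, head_dim = 1, 2, 4, 8
query = torch.randn(bsz, n_heads, q_len, head_dim)
key = torch.randn(bsz, n_heads, q_len, head_dim)
value = torch.randn(bsz, n_heads, q_len, head_dim)

# Boolean lower-triangular mask: True means "may attend", so position i only sees positions <= i,
# mirroring the torch.tril causal_mask built in QWenAttention.forward after this change.
causal_mask = torch.tril(torch.ones(q_len, q_len, dtype=torch.bool)).view(1, 1, q_len, q_len)

attn_output = F.scaled_dot_product_attention(query, key, value, attn_mask=causal_mask)
print(attn_output.shape)                              # torch.Size([1, 2, 4, 8])
print(rms_norm_ref(attn_output, torch.ones(head_dim)).shape)  # torch.Size([1, 2, 4, 8])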