Update qwen model: add generator and sample.

parent 08f7b75efe
commit f6538c1111
@@ -26,8 +26,9 @@ model = QWenLMHeadModel(config)
 tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
 model = model.from_pretrained(
     model_dir, device_map="auto", trust_remote_code=True
-).eval()
+).train()
+# model.train()
+# model.zero_grad()
 
 # Generation length, top_p, and other related hyperparameters can be specified here
 model.generation_config = GenerationConfig.from_pretrained(
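Note (not part of the commit): this hunk switches the sample from `.eval()` to `.train()`, which keeps dropout active and enables gradient tracking for the commented-out training experiment added further down. For plain sampling the usual pattern would still be eval mode; a minimal sketch, assuming the `chat` API used by the surrounding Qwen sample code (the query string is a placeholder):

model.eval()  # restore inference behaviour; .train() adds dropout noise to sampled outputs
response, history = model.chat(tokenizer, "Hello", history=None)
print(response)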
@@ -6,6 +6,7 @@
 import copy
 import importlib
 import math
+import inspect
 import pathlib
 from typing import TYPE_CHECKING, Optional, Tuple, Union, Callable, List, Any, Generator
 
@@ -1071,6 +1072,22 @@ class QWenLMHeadModel(QWenPreTrainedModel):
                 shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
             )
+
+        # shift_labels = torch.ones([1,19]).to(lm_logits.device).to(torch.int64)
+
+        # shift_logits = lm_logits[..., :-1, :].contiguous()
+        # loss_fct = CrossEntropyLoss()
+        # loss = loss_fct(
+        #     shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
+        # )
+
+        # optimizer = torch.optim.Adam(self.parameters(), lr=2e-5)
+        # # optimizer = torch.optim.SGD(self.parameters(),lr=0.001)
+        # # pa = self.transformer.parameters()
+
+        # loss.backward()
+        # # optimizer.step()
+
 
         if not return_dict:
             output = (lm_logits,) + transformer_outputs[1:]
             return ((loss,) + output) if loss is not None else output
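Note (not part of the commit): the hunk above adds a commented-out, hard-coded training step inside `forward`. A hypothetical stand-alone equivalent, kept outside the model, would look roughly like this; `model_dir` is the same placeholder as in the sample script, the prompt is arbitrary, and `labels=input_ids` relies on the label/logit shift that `forward` already performs:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_dir = "Qwen/Qwen-7B-Chat"  # placeholder: any Qwen checkpoint directory
tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_dir, trust_remote_code=True).train()
optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)

batch = tokenizer("example prompt", return_tensors="pt").to(model.device)
loss = model(**batch, labels=batch["input_ids"]).loss  # causal-LM loss as computed above
loss.backward()
optimizer.step()
optimizer.zero_grad()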
@@ -1143,7 +1160,7 @@ class QWenLMHeadModel(QWenPreTrainedModel):
                     generation_config=generation_config,
                     **kwargs,
                 )
-
+        
         response = decode_tokens(
             outputs[0],
             tokenizer,
@@ -1258,7 +1275,7 @@ class QWenLMHeadModel(QWenPreTrainedModel):
             else:
                 logits_processor.append(stop_words_logits_processor)
 
-        return super().generate(
+        return self.generate_base(
             inputs,
             generation_config=generation_config,
             logits_processor=logits_processor,
@@ -1270,6 +1287,404 @@ class QWenLMHeadModel(QWenPreTrainedModel):
             **kwargs,
         )
+
+    def generate_base(
+        self,
+        inputs: Optional[torch.Tensor] = None,
+        generation_config: Optional[GenerationConfig] = None,
+        logits_processor: Optional[LogitsProcessorList] = None,
+        stopping_criteria: Optional[StoppingCriteriaList] = None,
+        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
+        synced_gpus: Optional[bool] = None,
+        assistant_model: Optional["PreTrainedModel"] = None,
+        streamer: Optional["BaseStreamer"] = None,
+        negative_prompt_ids: Optional[torch.Tensor] = None,
+        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> Union[GenerateOutput, torch.LongTensor]:
+
+        if synced_gpus is None:
+                synced_gpus = False
+
+        # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
+        self._validate_model_class()
+
+        # priority: `generation_config` argument > `model.generation_config` (the default generation config)
+        if generation_config is None:
+            # legacy: users may modify the model configuration to control generation. To trigger this legacy behavior,
+            # two conditions must be met
+            # 1) the generation config must have been created from the model config (`_from_model_config` field);
+            # 2) the generation config must have seen no modification since its creation (the hash is the same).
+            if self.generation_config._from_model_config and self.generation_config._original_object_hash == hash(
+                self.generation_config
+            ):
+                new_generation_config = GenerationConfig.from_model_config(self.config)
+                if new_generation_config != self.generation_config:
+                    warnings.warn(
+                        "You have modified the pretrained model configuration to control generation. This is a"
+                        " deprecated strategy to control generation and will be removed soon, in a future version."
+                        " Please use and modify the model generation configuration (see"
+                        " https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )"
+                    )
+                    self.generation_config = new_generation_config
+            generation_config = self.generation_config
+
+        generation_config = copy.deepcopy(generation_config)
+        model_kwargs = generation_config.update(**kwargs)  # All unused kwargs must be model kwargs
+        generation_config.validate()
+        self._validate_model_kwargs(model_kwargs.copy())
+
+        # 2. Set generation parameters if not already defined
+        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
+        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
+
+        if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
+            if model_kwargs.get("attention_mask", None) is None:
+                logger.warning(
+                    "The attention mask and the pad token id were not set. As a consequence, you may observe "
+                    "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
+                )
+            eos_token_id = generation_config.eos_token_id
+            if isinstance(eos_token_id, list):
+                eos_token_id = eos_token_id[0]
+            logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
+            generation_config.pad_token_id = eos_token_id
+
+        # 3. Define model inputs
+        # inputs_tensor has to be defined
+        # model_input_name is defined if model-specific keyword input is passed
+        # otherwise model_input_name is None
+        # all model-specific keyword inputs are removed from `model_kwargs`
+        inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
+            inputs, generation_config.bos_token_id, model_kwargs
+        )
+        batch_size = inputs_tensor.shape[0]
+
+        # 4. Define other model kwargs
+        model_kwargs["output_attentions"] = generation_config.output_attentions
+        model_kwargs["output_hidden_states"] = generation_config.output_hidden_states
+        # decoder-only models with inputs_embeds forwarding must use caching (otherwise we can't detect whether we are
+        # generating the first new token or not, and we only want to use the embeddings for the first new token)
+        if not self.config.is_encoder_decoder and model_input_name == "inputs_embeds":
+            model_kwargs["use_cache"] = True
+        else:
+            model_kwargs["use_cache"] = generation_config.use_cache
+
+        accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
+        requires_attention_mask = "encoder_outputs" not in model_kwargs
+
+        if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask:
+            model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
+                inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id
+            )
+
+        # decoder-only models should use left-padding for generation
+        if not self.config.is_encoder_decoder:
+            # If `input_ids` was given, check if the last id in any sequence is `pad_token_id`
+            # Note: If using, `inputs_embeds` this check does not work, because we want to be more hands-off.
+            if (
+                generation_config.pad_token_id is not None
+                and len(inputs_tensor.shape) == 2
+                and torch.sum(inputs_tensor[:, -1] == generation_config.pad_token_id) > 0
+            ):
+                logger.warning(
+                    "A decoder-only architecture is being used, but right-padding was detected! For correct "
+                    "generation results, please set `padding_side='left'` when initializing the tokenizer."
+                )
+
+        if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs:
+            # if model is encoder decoder encoder_outputs are created
+            # and added to `model_kwargs`
+            model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
+                inputs_tensor, model_kwargs, model_input_name
+            )
+
+        # 5. Prepare `input_ids` which will be used for auto-regressive generation
+        if self.config.is_encoder_decoder:
+            input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
+                batch_size=batch_size,
+                model_input_name=model_input_name,
+                model_kwargs=model_kwargs,
+                decoder_start_token_id=generation_config.decoder_start_token_id,
+                bos_token_id=generation_config.bos_token_id,
+                device=inputs_tensor.device,
+            )
+        else:
+            input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")
+
+        if streamer is not None:
+            streamer.put(input_ids.cpu())
+
+        # 6. Prepare `max_length` depending on other stopping criteria.
+        input_ids_length = input_ids.shape[-1]
+        has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
+        if generation_config.max_new_tokens is not None:
+            if not has_default_max_length and generation_config.max_length is not None:
+                logger.warning(
+                    f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
+                    f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
+                    "Please refer to the documentation for more information. "
+                    "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
+                )
+            generation_config.max_length = generation_config.max_new_tokens + input_ids_length
+        self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
+
+        # 7. determine generation mode
+        generation_mode = self._get_generation_mode(generation_config, assistant_model)
+
+        if streamer is not None and (generation_config.num_beams > 1):
+            raise ValueError(
+                "`streamer` cannot be used with beam search (yet!). Make sure that `num_beams` is set to 1."
+            )
+
+        if self.device.type != input_ids.device.type:
+            warnings.warn(
+                "You are calling .generate() with the `input_ids` being on a device type different"
+                f" than your model's device. `input_ids` is on {input_ids.device.type}, whereas the model"
+                f" is on {self.device.type}. You may experience unexpected behaviors or slower generation."
+                " Please make sure that you have put `input_ids` to the"
+                f" correct device by calling for example input_ids = input_ids.to('{self.device.type}') before"
+                " running `.generate()`.",
+                UserWarning,
+            )
+
+        # 8. prepare distribution pre_processing samplers
+        logits_processor = self._get_logits_processor(
+            generation_config=generation_config,
+            input_ids_seq_length=input_ids_length,
+            encoder_input_ids=inputs_tensor,
+            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
+            logits_processor=logits_processor,
+            model_kwargs=model_kwargs,
+            negative_prompt_ids=negative_prompt_ids,
+            negative_prompt_attention_mask=negative_prompt_attention_mask,
+        )
+
+        # 9. prepare stopping criteria
+        stopping_criteria = self._get_stopping_criteria(
+            generation_config=generation_config, stopping_criteria=stopping_criteria
+        )
+        # 10. go into different generation modes
+
+        # 11. prepare logits warper
+        logits_warper = self._get_logits_warper(generation_config)
+
+        # 12. expand input_ids with `num_return_sequences` additional sequences per batch
+        input_ids, model_kwargs = self._expand_inputs_for_generation(
+            input_ids=input_ids,
+            expand_size=generation_config.num_return_sequences,
+            is_encoder_decoder=self.config.is_encoder_decoder,
+            **model_kwargs,
+        )
+
+        # 13. run sample
+        return self.sample_base(
+            input_ids,
+            logits_processor=logits_processor,
+            logits_warper=logits_warper,
+            stopping_criteria=stopping_criteria,
+            pad_token_id=generation_config.pad_token_id,
+            eos_token_id=generation_config.eos_token_id,
+            output_scores=generation_config.output_scores,
+            return_dict_in_generate=generation_config.return_dict_in_generate,
+            synced_gpus=synced_gpus,
+            streamer=streamer,
+            **model_kwargs,
+        )
+
+    def sample_base(
+        self,
+        input_ids: torch.LongTensor,
+        logits_processor: Optional[LogitsProcessorList] = None,
+        stopping_criteria: Optional[StoppingCriteriaList] = None,
+        logits_warper: Optional[LogitsProcessorList] = None,
+        max_length: Optional[int] = None,
+        pad_token_id: Optional[int] = None,
+        eos_token_id: Optional[Union[int, List[int]]] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        output_scores: Optional[bool] = None,
+        return_dict_in_generate: Optional[bool] = None,
+        synced_gpus: bool = False,
+        streamer: Optional["BaseStreamer"] = None,
+        **model_kwargs,
+    ):
+
+        # init values
+        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
+        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
+        # if max_length is not None:
+        #     warnings.warn(
+        #         "`max_length` is deprecated in this function, use"
+        #         " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.",
+        #         UserWarning,
+        #     )
+        #     stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
+        logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
+        pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
+        eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
+        if isinstance(eos_token_id, int):
+            eos_token_id = [eos_token_id]
+        eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
+        output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
+        output_attentions = (
+            output_attentions if output_attentions is not None else self.generation_config.output_attentions
+        )
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
+        )
+        return_dict_in_generate = (
+            return_dict_in_generate
+            if return_dict_in_generate is not None
+            else self.generation_config.return_dict_in_generate
+        )
+
+        # init attention / hidden states / scores tuples
+        scores = () if (return_dict_in_generate and output_scores) else None
+        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
+        cross_attentions = () if (return_dict_in_generate and output_attentions) else None
+        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
+
+        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
+        if return_dict_in_generate and self.config.is_encoder_decoder:
+            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
+            encoder_hidden_states = (
+                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
+            )
+
+        # keep track of which sequences are already finished
+        unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
+
+        this_peer_finished = False  # used by synced_gpus only
+        # auto-regressive generation
+        while True:
+            # if synced_gpus:
+            #     # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
+            #     # The following logic allows an early break if all peers finished generating their sequence
+            #     this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
+            #     # send 0.0 if we finished, 1.0 otherwise
+            #     dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
+            #     # did all peers finish? the reduced sum will be 0.0 then
+            #     if this_peer_finished_flag.item() == 0.0:
+            #         break
+
+            # prepare model inputs
+            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
+
+            # forward pass to get next token
+            outputs = self(
+                **model_inputs,
+                return_dict=True,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+            )
+
+            if synced_gpus and this_peer_finished:
+                continue  # don't waste resources running the code we don't need
+
+            next_token_logits = outputs.logits[:, -1, :]
+
+            # pre-process distribution
+            next_token_scores = logits_processor(input_ids, next_token_logits)
+            next_token_scores = logits_warper(input_ids, next_token_scores)
+
+            # Store scores, attentions and hidden_states when required
+            if return_dict_in_generate:
+                if output_scores:
+                    scores += (next_token_scores,)
+                if output_attentions:
+                    decoder_attentions += (
+                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
+                    )
+                    if self.config.is_encoder_decoder:
+                        cross_attentions += (outputs.cross_attentions,)
+
+                if output_hidden_states:
+                    decoder_hidden_states += (
+                        (outputs.decoder_hidden_states,)
+                        if self.config.is_encoder_decoder
+                        else (outputs.hidden_states,)
+                    )
+
+            # sample
+            probs = nn.functional.softmax(next_token_scores, dim=-1)
+            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
+
+            # finished sentences should have their next token be a padding token
+            if eos_token_id is not None:
+                if pad_token_id is None:
+                    raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
+                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
+
+            # update generated ids, model inputs, and length for next step
+            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
+            if streamer is not None:
+                streamer.put(next_tokens.cpu())
+            model_kwargs = self._update_model_kwargs_for_generation(
+                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
+            )
+
+            # if eos_token was found in one sentence, set sentence to finished
+            if eos_token_id_tensor is not None:
+                unfinished_sequences = unfinished_sequences.mul(
+                    next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
+                )
+
+                # stop when each sentence is finished
+                if unfinished_sequences.max() == 0:
+                    this_peer_finished = True
+
+            # stop if we exceed the maximum length
+            if stopping_criteria(input_ids, scores):
+                this_peer_finished = True
+
+            if this_peer_finished and not synced_gpus:
+                break
+
+        if streamer is not None:
+            streamer.end()
+
+        return input_ids
+
+
+    # def backward(
+    #     self,
+    #     tokenizer,
+    #     query: str,
+    # ):
+    #     inputs = tokenizer.build_chat_input(query, history=[], role="user")
+    #     inputs = inputs.to(next(self.parameters()).device)
+
+    #     generation_config = copy.deepcopy(self.generation_config)
+    #     inputs_tensor = inputs["input_ids"]
+    #     input_ids = inputs_tensor.repeat_interleave(
+    #         generation_config.num_return_sequences, dim=0
+    #     )
+
+    #     input_ids_in = input_ids
+    #     batch_size, seq_length = input_ids_in.shape
+    #     position_ids_in = (
+    #         torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
+    #         .unsqueeze(0)
+    #         .repeat(batch_size, 1)
+    #     )
+    #     model_inputs = {"input_ids": input_ids_in, "position_ids": position_ids_in}
+
+    #     probs, next_tokens = self.transformer(
+    #         **model_inputs,
+    #         output_hidden_states=None,
+    #         tokenizer=tokenizer,
+    #     )
+
+    #     next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
+    #     # probs_target = probs
+    #     # probs_target[0, next_tokens] = probs_target[0, next_tokens] * 1.1
+
+    #     loss = probs[0, next_tokens]
+    #     loss.backward()
+
+    #     return loss
+
 
 
 class RotaryEmbedding(torch.nn.Module):
     def __init__(self, dim, base=10000):
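Note (not part of the commit): a minimal usage sketch of the new path. `chat()` calls `generate()`, which now dispatches to `generate_base()` and finally to `sample_base()` for token-by-token sampling. Assumptions: `model_dir` points at a Qwen chat checkpoint that ships this modified `modeling_qwen.py` and is loaded with `trust_remote_code=True`:

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig

model_dir = "Qwen/Qwen-7B-Chat"  # placeholder: local checkpoint containing the modified code
tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_dir, device_map="auto", trust_remote_code=True)
model.generation_config = GenerationConfig.from_pretrained(model_dir, trust_remote_code=True)

# chat() -> generate() -> generate_base() -> sample_base()
response, history = model.chat(tokenizer, "Hello", history=None)
print(response)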