Update qwen model add generator and sample.
This commit is contained in:
parent
08f7b75efe
commit
f6538c1111
|
@ -26,8 +26,9 @@ model = QWenLMHeadModel(config)
|
||||||
tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
|
tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
|
||||||
model = model.from_pretrained(
|
model = model.from_pretrained(
|
||||||
model_dir, device_map="auto", trust_remote_code=True
|
model_dir, device_map="auto", trust_remote_code=True
|
||||||
).eval()
|
).train()
|
||||||
|
# model.train()
|
||||||
|
# model.zero_grad()
|
||||||
|
|
||||||
# 可指定不同的生成长度、top_p等相关超参
|
# 可指定不同的生成长度、top_p等相关超参
|
||||||
model.generation_config = GenerationConfig.from_pretrained(
|
model.generation_config = GenerationConfig.from_pretrained(
|
||||||
|
|
|
@ -6,6 +6,7 @@
|
||||||
import copy
|
import copy
|
||||||
import importlib
|
import importlib
|
||||||
import math
|
import math
|
||||||
|
import inspect
|
||||||
import pathlib
|
import pathlib
|
||||||
from typing import TYPE_CHECKING, Optional, Tuple, Union, Callable, List, Any, Generator
|
from typing import TYPE_CHECKING, Optional, Tuple, Union, Callable, List, Any, Generator
|
||||||
|
|
||||||
|
@ -1071,6 +1072,22 @@ class QWenLMHeadModel(QWenPreTrainedModel):
|
||||||
shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
|
shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# shift_labels = torch.ones([1,19]).to(lm_logits.device).to(torch.int64)
|
||||||
|
|
||||||
|
# shift_logits = lm_logits[..., :-1, :].contiguous()
|
||||||
|
# loss_fct = CrossEntropyLoss()
|
||||||
|
# loss = loss_fct(
|
||||||
|
# shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
|
||||||
|
# )
|
||||||
|
|
||||||
|
# optimizer = torch.optim.Adam(self.parameters(), lr=2e-5)
|
||||||
|
# # optimizer = torch.optim.SGD(self.parameters(),lr=0.001)
|
||||||
|
# # pa = self.transformer.parameters()
|
||||||
|
|
||||||
|
# loss.backward()
|
||||||
|
# # optimizer.step()
|
||||||
|
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
output = (lm_logits,) + transformer_outputs[1:]
|
output = (lm_logits,) + transformer_outputs[1:]
|
||||||
return ((loss,) + output) if loss is not None else output
|
return ((loss,) + output) if loss is not None else output
|
||||||
|
@ -1258,7 +1275,7 @@ class QWenLMHeadModel(QWenPreTrainedModel):
|
||||||
else:
|
else:
|
||||||
logits_processor.append(stop_words_logits_processor)
|
logits_processor.append(stop_words_logits_processor)
|
||||||
|
|
||||||
return super().generate(
|
return self.generate_base(
|
||||||
inputs,
|
inputs,
|
||||||
generation_config=generation_config,
|
generation_config=generation_config,
|
||||||
logits_processor=logits_processor,
|
logits_processor=logits_processor,
|
||||||
|
@ -1270,6 +1287,404 @@ class QWenLMHeadModel(QWenPreTrainedModel):
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def generate_base(
|
||||||
|
self,
|
||||||
|
inputs: Optional[torch.Tensor] = None,
|
||||||
|
generation_config: Optional[GenerationConfig] = None,
|
||||||
|
logits_processor: Optional[LogitsProcessorList] = None,
|
||||||
|
stopping_criteria: Optional[StoppingCriteriaList] = None,
|
||||||
|
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
|
||||||
|
synced_gpus: Optional[bool] = None,
|
||||||
|
assistant_model: Optional["PreTrainedModel"] = None,
|
||||||
|
streamer: Optional["BaseStreamer"] = None,
|
||||||
|
negative_prompt_ids: Optional[torch.Tensor] = None,
|
||||||
|
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> Union[GenerateOutput, torch.LongTensor]:
|
||||||
|
|
||||||
|
if synced_gpus is None:
|
||||||
|
synced_gpus = False
|
||||||
|
|
||||||
|
# 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
|
||||||
|
self._validate_model_class()
|
||||||
|
|
||||||
|
# priority: `generation_config` argument > `model.generation_config` (the default generation config)
|
||||||
|
if generation_config is None:
|
||||||
|
# legacy: users may modify the model configuration to control generation. To trigger this legacy behavior,
|
||||||
|
# two conditions must be met
|
||||||
|
# 1) the generation config must have been created from the model config (`_from_model_config` field);
|
||||||
|
# 2) the generation config must have seen no modification since its creation (the hash is the same).
|
||||||
|
if self.generation_config._from_model_config and self.generation_config._original_object_hash == hash(
|
||||||
|
self.generation_config
|
||||||
|
):
|
||||||
|
new_generation_config = GenerationConfig.from_model_config(self.config)
|
||||||
|
if new_generation_config != self.generation_config:
|
||||||
|
warnings.warn(
|
||||||
|
"You have modified the pretrained model configuration to control generation. This is a"
|
||||||
|
" deprecated strategy to control generation and will be removed soon, in a future version."
|
||||||
|
" Please use and modify the model generation configuration (see"
|
||||||
|
" https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )"
|
||||||
|
)
|
||||||
|
self.generation_config = new_generation_config
|
||||||
|
generation_config = self.generation_config
|
||||||
|
|
||||||
|
generation_config = copy.deepcopy(generation_config)
|
||||||
|
model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs
|
||||||
|
generation_config.validate()
|
||||||
|
self._validate_model_kwargs(model_kwargs.copy())
|
||||||
|
|
||||||
|
# 2. Set generation parameters if not already defined
|
||||||
|
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
|
||||||
|
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
|
||||||
|
|
||||||
|
if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
|
||||||
|
if model_kwargs.get("attention_mask", None) is None:
|
||||||
|
logger.warning(
|
||||||
|
"The attention mask and the pad token id were not set. As a consequence, you may observe "
|
||||||
|
"unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
|
||||||
|
)
|
||||||
|
eos_token_id = generation_config.eos_token_id
|
||||||
|
if isinstance(eos_token_id, list):
|
||||||
|
eos_token_id = eos_token_id[0]
|
||||||
|
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
|
||||||
|
generation_config.pad_token_id = eos_token_id
|
||||||
|
|
||||||
|
# 3. Define model inputs
|
||||||
|
# inputs_tensor has to be defined
|
||||||
|
# model_input_name is defined if model-specific keyword input is passed
|
||||||
|
# otherwise model_input_name is None
|
||||||
|
# all model-specific keyword inputs are removed from `model_kwargs`
|
||||||
|
inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
|
||||||
|
inputs, generation_config.bos_token_id, model_kwargs
|
||||||
|
)
|
||||||
|
batch_size = inputs_tensor.shape[0]
|
||||||
|
|
||||||
|
# 4. Define other model kwargs
|
||||||
|
model_kwargs["output_attentions"] = generation_config.output_attentions
|
||||||
|
model_kwargs["output_hidden_states"] = generation_config.output_hidden_states
|
||||||
|
# decoder-only models with inputs_embeds forwarding must use caching (otherwise we can't detect whether we are
|
||||||
|
# generating the first new token or not, and we only want to use the embeddings for the first new token)
|
||||||
|
if not self.config.is_encoder_decoder and model_input_name == "inputs_embeds":
|
||||||
|
model_kwargs["use_cache"] = True
|
||||||
|
else:
|
||||||
|
model_kwargs["use_cache"] = generation_config.use_cache
|
||||||
|
|
||||||
|
accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
|
||||||
|
requires_attention_mask = "encoder_outputs" not in model_kwargs
|
||||||
|
|
||||||
|
if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask:
|
||||||
|
model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
|
||||||
|
inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id
|
||||||
|
)
|
||||||
|
|
||||||
|
# decoder-only models should use left-padding for generation
|
||||||
|
if not self.config.is_encoder_decoder:
|
||||||
|
# If `input_ids` was given, check if the last id in any sequence is `pad_token_id`
|
||||||
|
# Note: If using, `inputs_embeds` this check does not work, because we want to be more hands-off.
|
||||||
|
if (
|
||||||
|
generation_config.pad_token_id is not None
|
||||||
|
and len(inputs_tensor.shape) == 2
|
||||||
|
and torch.sum(inputs_tensor[:, -1] == generation_config.pad_token_id) > 0
|
||||||
|
):
|
||||||
|
logger.warning(
|
||||||
|
"A decoder-only architecture is being used, but right-padding was detected! For correct "
|
||||||
|
"generation results, please set `padding_side='left'` when initializing the tokenizer."
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs:
|
||||||
|
# if model is encoder decoder encoder_outputs are created
|
||||||
|
# and added to `model_kwargs`
|
||||||
|
model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
|
||||||
|
inputs_tensor, model_kwargs, model_input_name
|
||||||
|
)
|
||||||
|
|
||||||
|
# 5. Prepare `input_ids` which will be used for auto-regressive generation
|
||||||
|
if self.config.is_encoder_decoder:
|
||||||
|
input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
|
||||||
|
batch_size=batch_size,
|
||||||
|
model_input_name=model_input_name,
|
||||||
|
model_kwargs=model_kwargs,
|
||||||
|
decoder_start_token_id=generation_config.decoder_start_token_id,
|
||||||
|
bos_token_id=generation_config.bos_token_id,
|
||||||
|
device=inputs_tensor.device,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")
|
||||||
|
|
||||||
|
if streamer is not None:
|
||||||
|
streamer.put(input_ids.cpu())
|
||||||
|
|
||||||
|
# 6. Prepare `max_length` depending on other stopping criteria.
|
||||||
|
input_ids_length = input_ids.shape[-1]
|
||||||
|
has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
|
||||||
|
if generation_config.max_new_tokens is not None:
|
||||||
|
if not has_default_max_length and generation_config.max_length is not None:
|
||||||
|
logger.warning(
|
||||||
|
f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
|
||||||
|
f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
|
||||||
|
"Please refer to the documentation for more information. "
|
||||||
|
"(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
|
||||||
|
)
|
||||||
|
generation_config.max_length = generation_config.max_new_tokens + input_ids_length
|
||||||
|
self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
|
||||||
|
|
||||||
|
# 7. determine generation mode
|
||||||
|
generation_mode = self._get_generation_mode(generation_config, assistant_model)
|
||||||
|
|
||||||
|
if streamer is not None and (generation_config.num_beams > 1):
|
||||||
|
raise ValueError(
|
||||||
|
"`streamer` cannot be used with beam search (yet!). Make sure that `num_beams` is set to 1."
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.device.type != input_ids.device.type:
|
||||||
|
warnings.warn(
|
||||||
|
"You are calling .generate() with the `input_ids` being on a device type different"
|
||||||
|
f" than your model's device. `input_ids` is on {input_ids.device.type}, whereas the model"
|
||||||
|
f" is on {self.device.type}. You may experience unexpected behaviors or slower generation."
|
||||||
|
" Please make sure that you have put `input_ids` to the"
|
||||||
|
f" correct device by calling for example input_ids = input_ids.to('{self.device.type}') before"
|
||||||
|
" running `.generate()`.",
|
||||||
|
UserWarning,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 8. prepare distribution pre_processing samplers
|
||||||
|
logits_processor = self._get_logits_processor(
|
||||||
|
generation_config=generation_config,
|
||||||
|
input_ids_seq_length=input_ids_length,
|
||||||
|
encoder_input_ids=inputs_tensor,
|
||||||
|
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
|
||||||
|
logits_processor=logits_processor,
|
||||||
|
model_kwargs=model_kwargs,
|
||||||
|
negative_prompt_ids=negative_prompt_ids,
|
||||||
|
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 9. prepare stopping criteria
|
||||||
|
stopping_criteria = self._get_stopping_criteria(
|
||||||
|
generation_config=generation_config, stopping_criteria=stopping_criteria
|
||||||
|
)
|
||||||
|
# 10. go into different generation modes
|
||||||
|
|
||||||
|
# 11. prepare logits warper
|
||||||
|
logits_warper = self._get_logits_warper(generation_config)
|
||||||
|
|
||||||
|
# 12. expand input_ids with `num_return_sequences` additional sequences per batch
|
||||||
|
input_ids, model_kwargs = self._expand_inputs_for_generation(
|
||||||
|
input_ids=input_ids,
|
||||||
|
expand_size=generation_config.num_return_sequences,
|
||||||
|
is_encoder_decoder=self.config.is_encoder_decoder,
|
||||||
|
**model_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 13. run sample
|
||||||
|
return self.sample_base(
|
||||||
|
input_ids,
|
||||||
|
logits_processor=logits_processor,
|
||||||
|
logits_warper=logits_warper,
|
||||||
|
stopping_criteria=stopping_criteria,
|
||||||
|
pad_token_id=generation_config.pad_token_id,
|
||||||
|
eos_token_id=generation_config.eos_token_id,
|
||||||
|
output_scores=generation_config.output_scores,
|
||||||
|
return_dict_in_generate=generation_config.return_dict_in_generate,
|
||||||
|
synced_gpus=synced_gpus,
|
||||||
|
streamer=streamer,
|
||||||
|
**model_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def sample_base(
|
||||||
|
self,
|
||||||
|
input_ids: torch.LongTensor,
|
||||||
|
logits_processor: Optional[LogitsProcessorList] = None,
|
||||||
|
stopping_criteria: Optional[StoppingCriteriaList] = None,
|
||||||
|
logits_warper: Optional[LogitsProcessorList] = None,
|
||||||
|
max_length: Optional[int] = None,
|
||||||
|
pad_token_id: Optional[int] = None,
|
||||||
|
eos_token_id: Optional[Union[int, List[int]]] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
output_scores: Optional[bool] = None,
|
||||||
|
return_dict_in_generate: Optional[bool] = None,
|
||||||
|
synced_gpus: bool = False,
|
||||||
|
streamer: Optional["BaseStreamer"] = None,
|
||||||
|
**model_kwargs,
|
||||||
|
):
|
||||||
|
|
||||||
|
# init values
|
||||||
|
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
|
||||||
|
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
|
||||||
|
# if max_length is not None:
|
||||||
|
# warnings.warn(
|
||||||
|
# "`max_length` is deprecated in this function, use"
|
||||||
|
# " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.",
|
||||||
|
# UserWarning,
|
||||||
|
# )
|
||||||
|
# stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
|
||||||
|
logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
|
||||||
|
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
|
||||||
|
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
|
||||||
|
if isinstance(eos_token_id, int):
|
||||||
|
eos_token_id = [eos_token_id]
|
||||||
|
eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
|
||||||
|
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
|
||||||
|
output_attentions = (
|
||||||
|
output_attentions if output_attentions is not None else self.generation_config.output_attentions
|
||||||
|
)
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
|
||||||
|
)
|
||||||
|
return_dict_in_generate = (
|
||||||
|
return_dict_in_generate
|
||||||
|
if return_dict_in_generate is not None
|
||||||
|
else self.generation_config.return_dict_in_generate
|
||||||
|
)
|
||||||
|
|
||||||
|
# init attention / hidden states / scores tuples
|
||||||
|
scores = () if (return_dict_in_generate and output_scores) else None
|
||||||
|
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
|
||||||
|
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
|
||||||
|
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
|
||||||
|
|
||||||
|
# if model is an encoder-decoder, retrieve encoder attention weights and hidden states
|
||||||
|
if return_dict_in_generate and self.config.is_encoder_decoder:
|
||||||
|
encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
|
||||||
|
encoder_hidden_states = (
|
||||||
|
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
|
||||||
|
)
|
||||||
|
|
||||||
|
# keep track of which sequences are already finished
|
||||||
|
unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
|
||||||
|
|
||||||
|
this_peer_finished = False # used by synced_gpus only
|
||||||
|
# auto-regressive generation
|
||||||
|
while True:
|
||||||
|
# if synced_gpus:
|
||||||
|
# # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
|
||||||
|
# # The following logic allows an early break if all peers finished generating their sequence
|
||||||
|
# this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
|
||||||
|
# # send 0.0 if we finished, 1.0 otherwise
|
||||||
|
# dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
|
||||||
|
# # did all peers finish? the reduced sum will be 0.0 then
|
||||||
|
# if this_peer_finished_flag.item() == 0.0:
|
||||||
|
# break
|
||||||
|
|
||||||
|
# prepare model inputs
|
||||||
|
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
|
||||||
|
|
||||||
|
# forward pass to get next token
|
||||||
|
outputs = self(
|
||||||
|
**model_inputs,
|
||||||
|
return_dict=True,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
)
|
||||||
|
|
||||||
|
if synced_gpus and this_peer_finished:
|
||||||
|
continue # don't waste resources running the code we don't need
|
||||||
|
|
||||||
|
next_token_logits = outputs.logits[:, -1, :]
|
||||||
|
|
||||||
|
# pre-process distribution
|
||||||
|
next_token_scores = logits_processor(input_ids, next_token_logits)
|
||||||
|
next_token_scores = logits_warper(input_ids, next_token_scores)
|
||||||
|
|
||||||
|
# Store scores, attentions and hidden_states when required
|
||||||
|
if return_dict_in_generate:
|
||||||
|
if output_scores:
|
||||||
|
scores += (next_token_scores,)
|
||||||
|
if output_attentions:
|
||||||
|
decoder_attentions += (
|
||||||
|
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
|
||||||
|
)
|
||||||
|
if self.config.is_encoder_decoder:
|
||||||
|
cross_attentions += (outputs.cross_attentions,)
|
||||||
|
|
||||||
|
if output_hidden_states:
|
||||||
|
decoder_hidden_states += (
|
||||||
|
(outputs.decoder_hidden_states,)
|
||||||
|
if self.config.is_encoder_decoder
|
||||||
|
else (outputs.hidden_states,)
|
||||||
|
)
|
||||||
|
|
||||||
|
# sample
|
||||||
|
probs = nn.functional.softmax(next_token_scores, dim=-1)
|
||||||
|
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
|
||||||
|
|
||||||
|
# finished sentences should have their next token be a padding token
|
||||||
|
if eos_token_id is not None:
|
||||||
|
if pad_token_id is None:
|
||||||
|
raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
|
||||||
|
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
|
||||||
|
|
||||||
|
# update generated ids, model inputs, and length for next step
|
||||||
|
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
|
||||||
|
if streamer is not None:
|
||||||
|
streamer.put(next_tokens.cpu())
|
||||||
|
model_kwargs = self._update_model_kwargs_for_generation(
|
||||||
|
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
|
||||||
|
)
|
||||||
|
|
||||||
|
# if eos_token was found in one sentence, set sentence to finished
|
||||||
|
if eos_token_id_tensor is not None:
|
||||||
|
unfinished_sequences = unfinished_sequences.mul(
|
||||||
|
next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
|
||||||
|
)
|
||||||
|
|
||||||
|
# stop when each sentence is finished
|
||||||
|
if unfinished_sequences.max() == 0:
|
||||||
|
this_peer_finished = True
|
||||||
|
|
||||||
|
# stop if we exceed the maximum length
|
||||||
|
if stopping_criteria(input_ids, scores):
|
||||||
|
this_peer_finished = True
|
||||||
|
|
||||||
|
if this_peer_finished and not synced_gpus:
|
||||||
|
break
|
||||||
|
|
||||||
|
if streamer is not None:
|
||||||
|
streamer.end()
|
||||||
|
|
||||||
|
return input_ids
|
||||||
|
|
||||||
|
|
||||||
|
# def backward(
|
||||||
|
# self,
|
||||||
|
# tokenizer,
|
||||||
|
# query: str,
|
||||||
|
# ):
|
||||||
|
# inputs = tokenizer.build_chat_input(query, history=[], role="user")
|
||||||
|
# inputs = inputs.to(next(self.parameters()).device)
|
||||||
|
|
||||||
|
# generation_config = copy.deepcopy(self.generation_config)
|
||||||
|
# inputs_tensor = inputs["input_ids"]
|
||||||
|
# input_ids = inputs_tensor.repeat_interleave(
|
||||||
|
# generation_config.num_return_sequences, dim=0
|
||||||
|
# )
|
||||||
|
|
||||||
|
# input_ids_in = input_ids
|
||||||
|
# batch_size, seq_length = input_ids_in.shape
|
||||||
|
# position_ids_in = (
|
||||||
|
# torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
|
||||||
|
# .unsqueeze(0)
|
||||||
|
# .repeat(batch_size, 1)
|
||||||
|
# )
|
||||||
|
# model_inputs = {"input_ids": input_ids_in, "position_ids": position_ids_in}
|
||||||
|
|
||||||
|
# probs, next_tokens = self.transformer(
|
||||||
|
# **model_inputs,
|
||||||
|
# output_hidden_states=None,
|
||||||
|
# tokenizer=tokenizer,
|
||||||
|
# )
|
||||||
|
|
||||||
|
# next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
|
||||||
|
# # probs_target = probs
|
||||||
|
# # probs_target[0, next_tokens] = probs_target[0, next_tokens] * 1.1
|
||||||
|
|
||||||
|
# loss = probs[0, next_tokens]
|
||||||
|
# loss.backward()
|
||||||
|
|
||||||
|
# return loss
|
||||||
|
|
||||||
|
|
||||||
class RotaryEmbedding(torch.nn.Module):
|
class RotaryEmbedding(torch.nn.Module):
|
||||||
def __init__(self, dim, base=10000):
|
def __init__(self, dim, base=10000):
|
||||||
|
|
Loading…
Reference in New Issue