Update and try seqgpt

Colin 2024-01-07 15:06:39 +08:00
parent 0dd2f2bab4
commit 467c78d83d
2 changed files with 14 additions and 1 deletion

View File

@@ -55,7 +55,7 @@ init_kwargs["name_or_path"] = model_dir
 tokenizer = BloomTokenizerFast(*init_inputs, **init_kwargs)
 model = BloomForCausalLM(config)
-model = model.from_pretrained(model_dir).cuda().eval()
+model = model.from_pretrained(model_dir).cuda().train()
 prompt = "输入: 中国的首都在哪里\n输出: "
 prompt = "输入: 美国的首都在哪里\n输出: "

View File

@@ -884,6 +884,19 @@ class BloomForCausalLM(BloomPreTrainedModel):
     shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length)
 )
+# for test train
+# shift_logits = lm_logits[..., :-1, :].contiguous()
+# shift_labels = torch.ones([4,9], requires_grad=True).to(lm_logits.device).to(torch.int64)
+# batch_size, seq_length, vocab_size = shift_logits.shape
+# optimizer = torch.optim.SGD(self.parameters(),lr=0.001)
+# pa = self.transformer.parameters()
+# loss_fct = CrossEntropyLoss()
+# loss = loss_fct(
+#     shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length)
+# )
+# loss.backward()
+# optimizer.step()
 if not return_dict:
     output = (lm_logits,) + transformer_outputs[1:]
     return ((loss,) + output) if loss is not None else output
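
The block added above (kept commented out) sketches a single supervised training step executed inside BloomForCausalLM.forward: build dummy all-ones labels, compute the cross-entropy loss over the shifted logits, and apply one SGD update to the model's own parameters. Below is a standalone sketch of the same sanity check performed outside the model, assuming a local checkpoint directory model_dir and the prompt from the first file; the path, the max length, and dropping the requires_grad flag are my assumptions, not part of the commit:

import torch
from torch.nn import CrossEntropyLoss
from transformers import BloomForCausalLM, BloomTokenizerFast

model_dir = "./seqgpt-bloom"  # placeholder checkpoint directory

tokenizer = BloomTokenizerFast.from_pretrained(model_dir)
# .train() matches the change in the first file and keeps gradients enabled.
model = BloomForCausalLM.from_pretrained(model_dir).cuda().train()

prompt = "输入: 美国的首都在哪里\n输出: "
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")

optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

# Forward pass without labels, then compute the loss by hand, mirroring the
# commented-out test code in forward.
lm_logits = model(**inputs).logits
shift_logits = lm_logits[..., :-1, :].contiguous()

# Dummy all-ones labels shaped like the shifted logits. requires_grad is omitted:
# CrossEntropyLoss targets are plain int64 tensors, and gradients flow through the
# logits rather than the labels.
shift_labels = torch.ones(shift_logits.shape[:-1], dtype=torch.int64, device=lm_logits.device)

batch_size, seq_length, vocab_size = shift_logits.shape
loss_fct = CrossEntropyLoss()
loss = loss_fct(
    shift_logits.view(batch_size * seq_length, vocab_size),
    shift_labels.view(batch_size * seq_length),
)

loss.backward()
optimizer.step()

A real fine-tuning loop would instead pass labels to forward (which already computes this shifted cross-entropy loss), call optimizer.zero_grad() between steps, and iterate over a dataset; the block in this commit looks like a one-off gradient sanity check rather than a training loop.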