Update and try seqgpt
parent 0dd2f2bab4
commit 467c78d83d
@@ -55,7 +55,7 @@ init_kwargs["name_or_path"] = model_dir
 tokenizer = BloomTokenizerFast(*init_inputs, **init_kwargs)
 
 model = BloomForCausalLM(config)
-model = model.from_pretrained(model_dir).cuda().eval()
+model = model.from_pretrained(model_dir).cuda().train()
 
-prompt = "输入: 中国的首都在哪里\n输出: "
+prompt = "输入: 美国的首都在哪里\n输出: "
 
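This first hunk puts the demo script into training mode (eval() → train()) and swaps the test prompt from "输入: 中国的首都在哪里\n输出: " ("Input: Where is the capital of China? Output: ") to "输入: 美国的首都在哪里\n输出: " ("Input: Where is the capital of the USA? Output: "). For orientation, a minimal sketch of how such a script is usually driven end to end, assuming model_dir points at a Bloom-style SeqGPT checkpoint and using the classmethod from_pretrained plus generate() rather than the instance-level from_pretrained call recorded in the diff:

import torch
from transformers import BloomForCausalLM, BloomTokenizerFast

model_dir = "path/to/seqgpt-bloom-checkpoint"  # hypothetical path, not taken from the commit

tokenizer = BloomTokenizerFast.from_pretrained(model_dir)
model = BloomForCausalLM.from_pretrained(model_dir).cuda()
model.eval()  # the commit flips this to train() to exercise the test-training block in the next hunk

prompt = "输入: 美国的首都在哪里\n输出: "  # "Input: Where is the capital of the USA? Output: "
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
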
@@ -884,6 +884,19 @@ class BloomForCausalLM(BloomPreTrainedModel):
                 shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length)
             )
 
+        # for test train
+        # shift_logits = lm_logits[..., :-1, :].contiguous()
+        # shift_labels = torch.ones([4,9], requires_grad=True).to(lm_logits.device).to(torch.int64)
+        # batch_size, seq_length, vocab_size = shift_logits.shape
+        # optimizer = torch.optim.SGD(self.parameters(),lr=0.001)
+        # pa = self.transformer.parameters()
+        # loss_fct = CrossEntropyLoss()
+        # loss = loss_fct(
+        #     shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length)
+        # )
+        # loss.backward()
+        # optimizer.step()
+
         if not return_dict:
             output = (lm_logits,) + transformer_outputs[1:]
             return ((loss,) + output) if loss is not None else output
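
The block added in this hunk is commented out; when enabled it would run a single manual optimization step inside forward(): rebuild the shifted logits, substitute dummy integer labels of shape [4, 9], and apply one SGD update on the cross-entropy loss. A self-contained sketch of the same smoke test, written as a standalone script rather than inside the model's forward() (the checkpoint path is a placeholder, and shifted input ids stand in for the hard-coded torch.ones labels):

import torch
from torch.nn import CrossEntropyLoss
from transformers import BloomForCausalLM, BloomTokenizerFast

model_dir = "path/to/seqgpt-bloom-checkpoint"  # hypothetical path, not taken from the commit
tokenizer = BloomTokenizerFast.from_pretrained(model_dir)
model = BloomForCausalLM.from_pretrained(model_dir).cuda().train()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

# One smoke-test step on the demo prompt; any tokenized text works here.
batch = tokenizer("输入: 美国的首都在哪里\n输出: ", return_tensors="pt").to("cuda")
input_ids = batch["input_ids"]

lm_logits = model(**batch).logits
# Shift so that tokens < n predict n, mirroring the commented-out code above.
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = input_ids[..., 1:].contiguous()
batch_size, seq_length, vocab_size = shift_logits.shape

loss = CrossEntropyLoss()(
    shift_logits.view(batch_size * seq_length, vocab_size),
    shift_labels.view(batch_size * seq_length),
)
loss.backward()
optimizer.step()
optimizer.zero_grad()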