Update and try seqgpt
This commit is contained in:
parent
0dd2f2bab4
commit
467c78d83d
|
@ -55,7 +55,7 @@ init_kwargs["name_or_path"] = model_dir
|
||||||
tokenizer = BloomTokenizerFast(*init_inputs, **init_kwargs)
|
tokenizer = BloomTokenizerFast(*init_inputs, **init_kwargs)
|
||||||
|
|
||||||
model = BloomForCausalLM(config)
|
model = BloomForCausalLM(config)
|
||||||
model = model.from_pretrained(model_dir).cuda().eval()
|
model = model.from_pretrained(model_dir).cuda().train()
|
||||||
|
|
||||||
prompt = "输入: 中国的首都在哪里\n输出: "
|
prompt = "输入: 中国的首都在哪里\n输出: "
|
||||||
prompt = "输入: 美国的首都在哪里\n输出: "
|
prompt = "输入: 美国的首都在哪里\n输出: "
|
||||||
|
|
|
@ -884,6 +884,19 @@ class BloomForCausalLM(BloomPreTrainedModel):
|
||||||
shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length)
|
shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# for test train
|
||||||
|
# shift_logits = lm_logits[..., :-1, :].contiguous()
|
||||||
|
# shift_labels = torch.ones([4,9], requires_grad=True).to(lm_logits.device).to(torch.int64)
|
||||||
|
# batch_size, seq_length, vocab_size = shift_logits.shape
|
||||||
|
# optimizer = torch.optim.SGD(self.parameters(),lr=0.001)
|
||||||
|
# pa = self.transformer.parameters()
|
||||||
|
# loss_fct = CrossEntropyLoss()
|
||||||
|
# loss = loss_fct(
|
||||||
|
# shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length)
|
||||||
|
# )
|
||||||
|
# loss.backward()
|
||||||
|
# optimizer.step()
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
output = (lm_logits,) + transformer_outputs[1:]
|
output = (lm_logits,) + transformer_outputs[1:]
|
||||||
return ((loss,) + output) if loss is not None else output
|
return ((loss,) + output) if loss is not None else output
|
||||||
|
|
Loading…
Reference in New Issue