Refine sample code.
This commit is contained in:
parent
10268c4414
commit
72787b9268
|
@ -5,7 +5,7 @@
|
||||||
|
|
||||||
input_ids = tokenizer.build_chat_input(query, history=history, role=role)
|
input_ids = tokenizer.build_chat_input(query, history=history, role=role)
|
||||||
|
|
||||||
input_ids -> [1, 6]
|
input_ids -> [1, 6] 1:batch_num 6:sequence_length
|
||||||
inputs_embeds -> [6, 1, 4096] 4096:hidden_size
|
inputs_embeds -> [6, 1, 4096] 4096:hidden_size
|
||||||
rotary_pos_emb -> [6, 1, 32, 2] 32:pos的编码维度 2:cos+sin
|
rotary_pos_emb -> [6, 1, 32, 2] 32:pos的编码维度 2:cos+sin
|
||||||
|
|
||||||
|
@ -17,7 +17,7 @@ lm_logits = self.output_layer(hidden_states)
|
||||||
lm_logits = lm_logits.transpose(0, 1).contiguous() -> [1, 1, 65024]
|
lm_logits = lm_logits.transpose(0, 1).contiguous() -> [1, 1, 65024]
|
||||||
|
|
||||||
probs = softmax(lm_logits) -> [1, 65024]
|
probs = softmax(lm_logits) -> [1, 65024]
|
||||||
next_tokens = torch.multinomial(probs, num_samples=1) 采样 -> [1]
|
next_tokens = torch.multinomial(probs, num_samples=1) 采样 -> [1] 1:batch_num
|
||||||
input_ids = torch.cat([input_ids, next_tokens]) -> [1, 7]
|
input_ids = torch.cat([input_ids, next_tokens]) -> [1, 7] 1:batch_num
|
||||||
|
|
||||||
response = tokenizer.decode(outputs)
|
response = tokenizer.decode(outputs)
|
|
@ -676,11 +676,9 @@ class ChatGLMForConditionalGeneration(nn.Module):
|
||||||
eos_token_id = [eos_token_id]
|
eos_token_id = [eos_token_id]
|
||||||
eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device)
|
eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device)
|
||||||
|
|
||||||
unfinished_sequences = torch.ones(
|
isFinished = torch.zeros(
|
||||||
input_ids.shape[0], dtype=torch.long, device=input_ids.device
|
input_ids.shape[0], dtype=torch.long, device=input_ids.device
|
||||||
)
|
)
|
||||||
|
|
||||||
this_peer_finished = False # used by synced_gpus only
|
|
||||||
while True:
|
while True:
|
||||||
input_ids_in = input_ids
|
input_ids_in = input_ids
|
||||||
batch_size, seq_length = input_ids_in.shape
|
batch_size, seq_length = input_ids_in.shape
|
||||||
|
@ -699,21 +697,13 @@ class ChatGLMForConditionalGeneration(nn.Module):
|
||||||
probs = nn.functional.softmax(next_token_logits, dim=-1)
|
probs = nn.functional.softmax(next_token_logits, dim=-1)
|
||||||
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
|
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
|
||||||
|
|
||||||
# finished sentences should have their next token be a padding token
|
# finished sentences should have their next token replaced by a padding token
|
||||||
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (
|
pad_token = pad_token_id * isFinished
|
||||||
1 - unfinished_sequences
|
next_tokens = next_tokens * (1 - isFinished) + pad_token
|
||||||
)
|
|
||||||
|
|
||||||
# update generated ids, model inputs, and length for next step
|
|
||||||
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
|
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
|
||||||
|
|
||||||
# if eos_token was found in one sentence, set sentence to finished
|
isFinished = isFinished | next_tokens.eq(eos_token_id_tensor)
|
||||||
unfinished_sequences = unfinished_sequences.mul(
|
if isFinished.min() == 1: # every sequence in the batch has finished
|
||||||
next_tokens.tile(eos_token_id_tensor.shape[0], 1)
|
|
||||||
.ne(eos_token_id_tensor.unsqueeze(1))
|
|
||||||
.prod(dim=0)
|
|
||||||
)
|
|
||||||
if unfinished_sequences.max() == 0:
|
|
||||||
break
|
break
|
||||||
|
|
||||||
return input_ids
|
return input_ids
|
||||||
|
|
|
@ -0,0 +1,31 @@
|
||||||
|
import torch
|
||||||
|
|
||||||
|
x = torch.tensor([[1, 2], [3, 4]])
|
||||||
|
|
||||||
|
print(x)
|
||||||
|
print("x.tile((2)) -> ", x.tile((2)).shape)
|
||||||
|
print(x.tile((2)))
|
||||||
|
|
||||||
|
print()
|
||||||
|
print("x.tile((2, 1)) -> ", x.tile((2, 1)).shape)
|
||||||
|
print(x.tile((2, 1)))
|
||||||
|
|
||||||
|
print()
|
||||||
|
print("x.tile((2, 1, 2)) -> ", x.tile((2, 1, 2)).shape)
|
||||||
|
print(x.tile((2, 1, 2)))
|
||||||
|
|
||||||
|
print()
|
||||||
|
print("x.tile((2, 1, 1)) -> ", x.tile((2, 1, 1)).shape)
|
||||||
|
print(x.tile((2, 1, 1)))
|
||||||
|
|
||||||
|
print()
|
||||||
|
y = torch.tensor([[2, 1], [3, 4]])
|
||||||
|
print(y.ne(x))
|
||||||
|
|
||||||
|
print()
|
||||||
|
print(x.prod(1))
|
||||||
|
print(x.prod(0))
|
||||||
|
|
||||||
|
print()
|
||||||
|
print(x.unsqueeze(1).shape)
|
||||||
|
print(x.unsqueeze(1).squeeze(1).shape)
|
Loading…
Reference in New Issue