Refine sample code.

Colin 2023-12-22 18:57:16 +08:00
parent 10268c4414
commit 72787b9268
3 changed files with 40 additions and 19 deletions


@@ -5,7 +5,7 @@
 input_ids = tokenizer.build_chat_input(query, history=history, role=role)
-input_ids -> [1, 6]
+input_ids -> [1, 6]  1:batch_num 6:sequence_length
 inputs_embeds -> [6, 1, 4096]  4096:hidden_size
 rotary_pos_emb -> [6, 1, 32, 2]  32:positional encoding dim 2:cos+sin
@@ -17,7 +17,7 @@ lm_logits = self.output_layer(hidden_states)
 lm_logits = lm_logits.transpose(0, 1).contiguous() -> [1, 1, 65024]
 probs = softmax(lm_logits) -> [1, 65024]
-next_tokens = torch.multinomial(probs, num_samples=1)  sampling -> [1]
-input_ids = torch.cat([input_ids, next_tokens]) -> [1, 7]
+next_tokens = torch.multinomial(probs, num_samples=1)  sampling -> [1]  1:batch_num
+input_ids = torch.cat([input_ids, next_tokens]) -> [1, 7]  1:batch_num
 response = tokenizer.decode(outputs)
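
As a rough shape check of the loop above, here is a minimal, self-contained sketch of one sampling step; random logits stand in for the model output, and the batch of 1 and the 65024-token vocabulary are taken from the annotations above:

import torch
import torch.nn.functional as F

batch_num, vocab_size = 1, 65024
input_ids = torch.randint(0, vocab_size, (batch_num, 6))          # [1, 6]
lm_logits = torch.randn(batch_num, 1, vocab_size)                 # [1, 1, 65024]
probs = F.softmax(lm_logits[:, -1, :], dim=-1)                    # [1, 65024]
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)  # [1]  1:batch_num
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)  # [1, 7]
print(input_ids.shape)                                            # torch.Size([1, 7])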


@@ -676,11 +676,9 @@ class ChatGLMForConditionalGeneration(nn.Module):
     eos_token_id = [eos_token_id]
 eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device)
-unfinished_sequences = torch.ones(
+isFinished = torch.zeros(
     input_ids.shape[0], dtype=torch.long, device=input_ids.device
 )
-this_peer_finished = False  # used by synced_gpus only
 while True:
     input_ids_in = input_ids
     batch_size, seq_length = input_ids_in.shape
@@ -699,21 +697,13 @@ class ChatGLMForConditionalGeneration(nn.Module):
     probs = nn.functional.softmax(next_token_logits, dim=-1)
     next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
-    # finished sentences should have their next token be a padding token
-    next_tokens = next_tokens * unfinished_sequences + pad_token_id * (
-        1 - unfinished_sequences
-    )
-    # update generated ids, model inputs, and length for next step
+    # finished sentences get a padding token as their next token
+    pad_token = pad_token_id * isFinished
+    next_tokens = next_tokens * (1 - isFinished) + pad_token
     input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
-    # if eos_token was found in one sentence, set sentence to finished
-    unfinished_sequences = unfinished_sequences.mul(
-        next_tokens.tile(eos_token_id_tensor.shape[0], 1)
-        .ne(eos_token_id_tensor.unsqueeze(1))
-        .prod(dim=0)
-    )
-    if unfinished_sequences.max() == 0:
+    isFinished = isFinished | next_tokens.eq(eos_token_id_tensor)
+    if isFinished.min() == 1:  # the whole batch is finished
         break
 return input_ids
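
A toy run of the new isFinished bookkeeping outside the model (pad id 0, eos id 2, and the batch of 3 are made-up values for illustration): once a sequence has produced eos, its isFinished entry stays 1 and that row only receives padding afterwards.

import torch

pad_token_id, eos_token_id = 0, 2
eos_token_id_tensor = torch.tensor([eos_token_id])
isFinished = torch.zeros(3, dtype=torch.long)                # nothing finished yet

for next_tokens in (torch.tensor([5, 2, 7]),                 # the second sequence emits eos
                    torch.tensor([9, 4, 2])):                # the third sequence emits eos
    pad_token = pad_token_id * isFinished
    next_tokens = next_tokens * (1 - isFinished) + pad_token  # finished rows -> pad
    isFinished = isFinished | next_tokens.eq(eos_token_id_tensor)
    print(next_tokens, isFinished)
    if isFinished.min() == 1:                                 # all rows finished
        break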

tensor.py Normal file

@@ -0,0 +1,31 @@
import torch

x = torch.tensor([[1, 2], [3, 4]])
print(x)

# tile with a single repeat count repeats along the last dimension
print("x.tile((2)) -> ", x.tile((2)).shape)        # [2, 4]
print(x.tile((2)))
print()
# one repeat count per existing dimension
print("x.tile((2, 1)) -> ", x.tile((2, 1)).shape)  # [4, 2]
print(x.tile((2, 1)))
print()
# more repeat counts than dimensions: x is treated as shape [1, 2, 2] first
print("x.tile((2, 1, 2)) -> ", x.tile((2, 1, 2)).shape)  # [2, 2, 4]
print(x.tile((2, 1, 2)))
print()
print("x.tile((2, 1, 1)) -> ", x.tile((2, 1, 1)).shape)  # [2, 2, 2]
print(x.tile((2, 1, 1)))
print()

# ne: element-wise "not equal" comparison
y = torch.tensor([[2, 1], [3, 4]])
print(y.ne(x))
print()

# prod: product of the elements along a dimension
print(x.prod(1))  # [2, 12]
print(x.prod(0))  # [3, 8]
print()

# unsqueeze inserts a size-1 dimension; squeeze removes it again
print(x.unsqueeze(1).shape)             # [2, 1, 2]
print(x.unsqueeze(1).squeeze(1).shape)  # [2, 2]
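
For reference, the removed eos check in the old sample loop combines exactly these ops: tile the sampled tokens once per eos id, compare with ne, and take prod over the eos dimension so a row drops to 0 as soon as it matches any eos id. A small sketch with made-up token ids:

import torch

next_tokens = torch.tensor([5, 2, 7])
eos_token_id_tensor = torch.tensor([2])
keep = (next_tokens.tile(eos_token_id_tensor.shape[0], 1)
        .ne(eos_token_id_tensor.unsqueeze(1))
        .prod(dim=0))
print(keep)  # rows that have not hit eos stay nonzero; the row that matched eos goes to 0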